feat: PoC for shared states.
This commit is contained in:
parent
93231f8168
commit
3bcdc12c26
@ -1460,16 +1460,19 @@ class EventNamespace(AsyncNamespace):
|
|||||||
if disconnect_token:
|
if disconnect_token:
|
||||||
self.token_to_sid.pop(disconnect_token, None)
|
self.token_to_sid.pop(disconnect_token, None)
|
||||||
|
|
||||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
async def emit_update(
|
||||||
|
self, update: StateUpdate, sid: str, room: str | None = None
|
||||||
|
) -> None:
|
||||||
"""Emit an update to the client.
|
"""Emit an update to the client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update: The state update to send.
|
update: The state update to send.
|
||||||
sid: The Socket.IO session id.
|
sid: The Socket.IO session id.
|
||||||
|
room: The room to send the update to.
|
||||||
"""
|
"""
|
||||||
# Creating a task prevents the update from being blocked behind other coroutines.
|
# Creating a task prevents the update from being blocked behind other coroutines.
|
||||||
await asyncio.create_task(
|
await asyncio.create_task(
|
||||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_event(self, sid, data):
|
async def on_event(self, sid, data):
|
||||||
|
@ -442,11 +442,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls, mixin: bool = False, **kwargs):
|
def __init_subclass__(
|
||||||
|
cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs
|
||||||
|
):
|
||||||
"""Do some magic for the subclass initialization.
|
"""Do some magic for the subclass initialization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mixin: Whether the subclass is a mixin and should not be initialized.
|
mixin: Whether the subclass is a mixin and should not be initialized.
|
||||||
|
scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value.
|
||||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -460,6 +463,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if mixin:
|
if mixin:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Set the scope of the state.
|
||||||
|
cls._scope = scope
|
||||||
|
|
||||||
# Validate the module name.
|
# Validate the module name.
|
||||||
cls._validate_module_name()
|
cls._validate_module_name()
|
||||||
|
|
||||||
@ -1270,9 +1276,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# The requested state is missing, fetch from redis.
|
# The requested state is missing, fetch from redis.
|
||||||
pass
|
pass
|
||||||
parent_state = await state_manager.get_state(
|
parent_state = await state_manager.get_state(
|
||||||
token=_substate_key(
|
token=_substate_key(self._get_token(), parent_state_name),
|
||||||
self.router.session.client_token, parent_state_name
|
|
||||||
),
|
|
||||||
top_level=False,
|
top_level=False,
|
||||||
get_substates=False,
|
get_substates=False,
|
||||||
parent_state=parent_state,
|
parent_state=parent_state,
|
||||||
@ -1318,8 +1322,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||||
"(All states should already be available -- this is likely a bug).",
|
"(All states should already be available -- this is likely a bug).",
|
||||||
)
|
)
|
||||||
|
|
||||||
return await state_manager.get_state(
|
return await state_manager.get_state(
|
||||||
token=_substate_key(self.router.session.client_token, state_cls),
|
token=_substate_key(self._get_token(), state_cls),
|
||||||
top_level=False,
|
top_level=False,
|
||||||
get_substates=True,
|
get_substates=True,
|
||||||
parent_state=parent_state_of_state_cls,
|
parent_state=parent_state_of_state_cls,
|
||||||
@ -1425,6 +1430,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_token(self) -> str:
|
||||||
|
token = self.router.session.client_token
|
||||||
|
if self.__class__._scope is not None:
|
||||||
|
scope = None
|
||||||
|
if isinstance(self.__class__._scope, str):
|
||||||
|
scope = self.__class__._scope
|
||||||
|
else:
|
||||||
|
scope = getattr(self, self.__class__._scope._var_name)
|
||||||
|
|
||||||
|
token = scope
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
def _as_state_update(
|
def _as_state_update(
|
||||||
self,
|
self,
|
||||||
handler: EventHandler,
|
handler: EventHandler,
|
||||||
@ -1448,7 +1466,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
while state.parent_state is not None:
|
while state.parent_state is not None:
|
||||||
state = state.parent_state
|
state = state.parent_state
|
||||||
|
|
||||||
token = self.router.session.client_token
|
token = self._get_token()
|
||||||
|
|
||||||
# Convert valid EventHandler and EventSpec into Event
|
# Convert valid EventHandler and EventSpec into Event
|
||||||
fixed_events = fix_events(self._check_valid(handler, events), token)
|
fixed_events = fix_events(self._check_valid(handler, events), token)
|
||||||
@ -1909,7 +1927,7 @@ class OnLoadInternalState(State):
|
|||||||
return [
|
return [
|
||||||
*fix_events(
|
*fix_events(
|
||||||
load_events,
|
load_events,
|
||||||
self.router.session.client_token,
|
self._get_token(),
|
||||||
router_data=self.router_data,
|
router_data=self.router_data,
|
||||||
),
|
),
|
||||||
State.set_is_hydrated(True), # type: ignore
|
State.set_is_hydrated(True), # type: ignore
|
||||||
@ -2328,10 +2346,15 @@ class StateManagerMemory(StateManager):
|
|||||||
Returns:
|
Returns:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
"""
|
"""
|
||||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
||||||
token = _split_substate_key(token)[0]
|
|
||||||
if token not in self.states:
|
if token not in self.states:
|
||||||
self.states[token] = self.state(_reflex_internal_init=True)
|
self.states[token] = self.state(_reflex_internal_init=True)
|
||||||
|
|
||||||
|
# TODO This is a bit madness maybe.
|
||||||
|
token = self.states[token]._get_token()
|
||||||
|
|
||||||
|
if token not in self.states:
|
||||||
|
self.states[token] = self.state(_reflex_internal_init=True)
|
||||||
|
|
||||||
return self.states[token]
|
return self.states[token]
|
||||||
|
|
||||||
async def set_state(self, token: str, state: BaseState):
|
async def set_state(self, token: str, state: BaseState):
|
||||||
@ -2353,8 +2376,6 @@ class StateManagerMemory(StateManager):
|
|||||||
Yields:
|
Yields:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
"""
|
"""
|
||||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
||||||
token = _split_substate_key(token)[0]
|
|
||||||
if token not in self._states_locks:
|
if token not in self._states_locks:
|
||||||
async with self._state_manager_lock:
|
async with self._state_manager_lock:
|
||||||
if token not in self._states_locks:
|
if token not in self._states_locks:
|
||||||
|
@ -383,6 +383,8 @@ class Var:
|
|||||||
# Extra metadata associated with the Var
|
# Extra metadata associated with the Var
|
||||||
_var_data: Optional[VarData]
|
_var_data: Optional[VarData]
|
||||||
|
|
||||||
|
_var_is_scope: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
@ -390,6 +392,7 @@ class Var:
|
|||||||
_var_is_local: bool = True,
|
_var_is_local: bool = True,
|
||||||
_var_is_string: bool | None = None,
|
_var_is_string: bool | None = None,
|
||||||
_var_data: Optional[VarData] = None,
|
_var_data: Optional[VarData] = None,
|
||||||
|
_var_is_scope: bool = False,
|
||||||
) -> Var | None:
|
) -> Var | None:
|
||||||
"""Create a var from a value.
|
"""Create a var from a value.
|
||||||
|
|
||||||
@ -455,6 +458,7 @@ class Var:
|
|||||||
_var_is_local=_var_is_local,
|
_var_is_local=_var_is_local,
|
||||||
_var_is_string=_var_is_string if _var_is_string is not None else False,
|
_var_is_string=_var_is_string if _var_is_string is not None else False,
|
||||||
_var_data=_var_data,
|
_var_data=_var_data,
|
||||||
|
_var_is_scope=_var_is_scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1866,6 +1870,8 @@ class BaseVar(Var):
|
|||||||
# Extra metadata associated with the Var
|
# Extra metadata associated with the Var
|
||||||
_var_data: Optional[VarData] = dataclasses.field(default=None)
|
_var_data: Optional[VarData] = dataclasses.field(default=None)
|
||||||
|
|
||||||
|
_var_is_scope: bool = dataclasses.field(default=False)
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
"""Define a hash function for a var.
|
"""Define a hash function for a var.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user