feat: PoC for shared states.

This commit is contained in:
Andreas Eismann 2024-07-16 22:51:11 +02:00
parent 93231f8168
commit 3bcdc12c26
No known key found for this signature in database
3 changed files with 43 additions and 13 deletions

View File

@ -1460,16 +1460,19 @@ class EventNamespace(AsyncNamespace):
if disconnect_token: if disconnect_token:
self.token_to_sid.pop(disconnect_token, None) self.token_to_sid.pop(disconnect_token, None)
async def emit_update(self, update: StateUpdate, sid: str) -> None: async def emit_update(
self, update: StateUpdate, sid: str, room: str | None = None
) -> None:
"""Emit an update to the client. """Emit an update to the client.
Args: Args:
update: The state update to send. update: The state update to send.
sid: The Socket.IO session id. sid: The Socket.IO session id.
room: The room to send the update to.
""" """
# Creating a task prevents the update from being blocked behind other coroutines. # Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task( await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid)
) )
async def on_event(self, sid, data): async def on_event(self, sid, data):

View File

@ -442,11 +442,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
) )
@classmethod @classmethod
def __init_subclass__(cls, mixin: bool = False, **kwargs): def __init_subclass__(
cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs
):
"""Do some magic for the subclass initialization. """Do some magic for the subclass initialization.
Args: Args:
mixin: Whether the subclass is a mixin and should not be initialized. mixin: Whether the subclass is a mixin and should not be initialized.
scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value.
**kwargs: The kwargs to pass to the pydantic init_subclass method. **kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises: Raises:
@ -460,6 +463,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if mixin: if mixin:
return return
# Set the scope of the state.
cls._scope = scope
# Validate the module name. # Validate the module name.
cls._validate_module_name() cls._validate_module_name()
@ -1270,9 +1276,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The requested state is missing, fetch from redis. # The requested state is missing, fetch from redis.
pass pass
parent_state = await state_manager.get_state( parent_state = await state_manager.get_state(
token=_substate_key( token=_substate_key(self._get_token(), parent_state_name),
self.router.session.client_token, parent_state_name
),
top_level=False, top_level=False,
get_substates=False, get_substates=False,
parent_state=parent_state, parent_state=parent_state,
@ -1318,8 +1322,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).", "(All states should already be available -- this is likely a bug).",
) )
return await state_manager.get_state( return await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls), token=_substate_key(self._get_token(), state_cls),
top_level=False, top_level=False,
get_substates=True, get_substates=True,
parent_state=parent_state_of_state_cls, parent_state=parent_state_of_state_cls,
@ -1425,6 +1430,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
) )
def _get_token(self) -> str:
token = self.router.session.client_token
if self.__class__._scope is not None:
scope = None
if isinstance(self.__class__._scope, str):
scope = self.__class__._scope
else:
scope = getattr(self, self.__class__._scope._var_name)
token = scope
return token
def _as_state_update( def _as_state_update(
self, self,
handler: EventHandler, handler: EventHandler,
@ -1448,7 +1466,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
while state.parent_state is not None: while state.parent_state is not None:
state = state.parent_state state = state.parent_state
token = self.router.session.client_token token = self._get_token()
# Convert valid EventHandler and EventSpec into Event # Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token) fixed_events = fix_events(self._check_valid(handler, events), token)
@ -1909,7 +1927,7 @@ class OnLoadInternalState(State):
return [ return [
*fix_events( *fix_events(
load_events, load_events,
self.router.session.client_token, self._get_token(),
router_data=self.router_data, router_data=self.router_data,
), ),
State.set_is_hydrated(True), # type: ignore State.set_is_hydrated(True), # type: ignore
@ -2328,10 +2346,15 @@ class StateManagerMemory(StateManager):
Returns: Returns:
The state for the token. The state for the token.
""" """
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self.states: if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True) self.states[token] = self.state(_reflex_internal_init=True)
# TODO This is a bit madness maybe.
token = self.states[token]._get_token()
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token] return self.states[token]
async def set_state(self, token: str, state: BaseState): async def set_state(self, token: str, state: BaseState):
@ -2353,8 +2376,6 @@ class StateManagerMemory(StateManager):
Yields: Yields:
The state for the token. The state for the token.
""" """
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self._states_locks: if token not in self._states_locks:
async with self._state_manager_lock: async with self._state_manager_lock:
if token not in self._states_locks: if token not in self._states_locks:

View File

@ -383,6 +383,8 @@ class Var:
# Extra metadata associated with the Var # Extra metadata associated with the Var
_var_data: Optional[VarData] _var_data: Optional[VarData]
_var_is_scope: bool = False
@classmethod @classmethod
def create( def create(
cls, cls,
@ -390,6 +392,7 @@ class Var:
_var_is_local: bool = True, _var_is_local: bool = True,
_var_is_string: bool | None = None, _var_is_string: bool | None = None,
_var_data: Optional[VarData] = None, _var_data: Optional[VarData] = None,
_var_is_scope: bool = False,
) -> Var | None: ) -> Var | None:
"""Create a var from a value. """Create a var from a value.
@ -455,6 +458,7 @@ class Var:
_var_is_local=_var_is_local, _var_is_local=_var_is_local,
_var_is_string=_var_is_string if _var_is_string is not None else False, _var_is_string=_var_is_string if _var_is_string is not None else False,
_var_data=_var_data, _var_data=_var_data,
_var_is_scope=_var_is_scope,
) )
@classmethod @classmethod
@ -1866,6 +1870,8 @@ class BaseVar(Var):
# Extra metadata associated with the Var # Extra metadata associated with the Var
_var_data: Optional[VarData] = dataclasses.field(default=None) _var_data: Optional[VarData] = dataclasses.field(default=None)
_var_is_scope: bool = dataclasses.field(default=False)
def __hash__(self) -> int: def __hash__(self) -> int:
"""Define a hash function for a var. """Define a hash function for a var.