diff --git a/reflex/app.py b/reflex/app.py index 658ba1a1f..559861c24 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1460,16 +1460,19 @@ class EventNamespace(AsyncNamespace): if disconnect_token: self.token_to_sid.pop(disconnect_token, None) - async def emit_update(self, update: StateUpdate, sid: str) -> None: + async def emit_update( + self, update: StateUpdate, sid: str, room: str | None = None + ) -> None: """Emit an update to the client. Args: update: The state update to send. sid: The Socket.IO session id. + room: The room to send the update to. """ # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid) ) async def on_event(self, sid, data): diff --git a/reflex/state.py b/reflex/state.py index 49b5bd4a4..d2c5fa567 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -442,11 +442,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) @classmethod - def __init_subclass__(cls, mixin: bool = False, **kwargs): + def __init_subclass__( + cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs + ): """Do some magic for the subclass initialization. Args: mixin: Whether the subclass is a mixin and should not be initialized. + scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value. **kwargs: The kwargs to pass to the pydantic init_subclass method. Raises: @@ -460,6 +463,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if mixin: return + # Set the scope of the state. + cls._scope = scope + # Validate the module name. cls._validate_module_name() @@ -1270,9 +1276,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The requested state is missing, fetch from redis. pass parent_state = await state_manager.get_state( - token=_substate_key( - self.router.session.client_token, parent_state_name - ), + token=_substate_key(self._get_token(), parent_state_name), top_level=False, get_substates=False, parent_state=parent_state, @@ -1318,8 +1322,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) + return await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), + token=_substate_key(self._get_token(), state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, @@ -1425,6 +1430,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) + def _get_token(self) -> str: + token = self.router.session.client_token + if self.__class__._scope is not None: + scope = None + if isinstance(self.__class__._scope, str): + scope = self.__class__._scope + else: + scope = getattr(self, self.__class__._scope._var_name) + + token = scope + + return token + def _as_state_update( self, handler: EventHandler, @@ -1448,7 +1466,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): while state.parent_state is not None: state = state.parent_state - token = self.router.session.client_token + token = self._get_token() # Convert valid EventHandler and EventSpec into Event fixed_events = fix_events(self._check_valid(handler, events), token) @@ -1909,7 +1927,7 @@ class OnLoadInternalState(State): return [ *fix_events( load_events, - self.router.session.client_token, + self._get_token(), router_data=self.router_data, ), State.set_is_hydrated(True), # type: ignore @@ -2328,10 +2346,15 @@ class StateManagerMemory(StateManager): Returns: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] if token not in self.states: self.states[token] = self.state(_reflex_internal_init=True) + + # TODO This is a bit madness maybe. + token = self.states[token]._get_token() + + if token not in self.states: + self.states[token] = self.state(_reflex_internal_init=True) + return self.states[token] async def set_state(self, token: str, state: BaseState): @@ -2353,8 +2376,6 @@ class StateManagerMemory(StateManager): Yields: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] if token not in self._states_locks: async with self._state_manager_lock: if token not in self._states_locks: diff --git a/reflex/vars.py b/reflex/vars.py index 8d93f99c0..79b992cc4 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -383,6 +383,8 @@ class Var: # Extra metadata associated with the Var _var_data: Optional[VarData] + _var_is_scope: bool = False + @classmethod def create( cls, @@ -390,6 +392,7 @@ class Var: _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: Optional[VarData] = None, + _var_is_scope: bool = False, ) -> Var | None: """Create a var from a value. @@ -455,6 +458,7 @@ class Var: _var_is_local=_var_is_local, _var_is_string=_var_is_string if _var_is_string is not None else False, _var_data=_var_data, + _var_is_scope=_var_is_scope, ) @classmethod @@ -1866,6 +1870,8 @@ class BaseVar(Var): # Extra metadata associated with the Var _var_data: Optional[VarData] = dataclasses.field(default=None) + _var_is_scope: bool = dataclasses.field(default=False) + def __hash__(self) -> int: """Define a hash function for a var.