diff --git a/reflex/app.py b/reflex/app.py index 08cb4314e..10af41e99 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1108,7 +1108,9 @@ class App(MiddlewareMixin, LifespanMixin): # When the state is modified reset dirty status and emit the delta to the frontend. state._clean() await self.event_namespace.emit_update( - update=StateUpdate(delta=delta), + update=StateUpdate( + delta=delta, scopes=state.scopes_and_subscopes() + ), sid=state.router.session.session_id, ) @@ -1273,16 +1275,23 @@ async def process( from reflex.utils import telemetry try: - # Add request data to the state. - router_data = event.router_data - router_data.update( - { - constants.RouteVar.QUERY: format.format_query_params(event.router_data), - constants.RouteVar.CLIENT_TOKEN: event.token, - constants.RouteVar.SESSION_ID: sid, - constants.RouteVar.HEADERS: headers, - constants.RouteVar.CLIENT_IP: client_ip, - } + router_data = {} + if event.router_data: + # Add request data to the state. + router_data = event.router_data + router_data.update( + { + constants.RouteVar.QUERY: format.format_query_params( + event.router_data + ), + constants.RouteVar.CLIENT_TOKEN: event.token, + constants.RouteVar.SESSION_ID: sid, + constants.RouteVar.HEADERS: headers, + constants.RouteVar.CLIENT_IP: client_ip, + } + ) + print( + f"Processing event: {event.name} with payload: {event.payload} {event.substate_token}" ) # Get the state for the session exclusively. async with app.state_manager.modify_state(event.substate_token) as state: @@ -1319,7 +1328,7 @@ async def process( else: if app._process_background(state, event) is not None: # `final=True` allows the frontend send more events immediately. - yield StateUpdate(final=True) + yield StateUpdate(final=True, scopes=state.scopes_and_subscopes()) return # Process the event synchronously. @@ -1542,13 +1551,74 @@ class EventNamespace(AsyncNamespace): if disconnect_token: self.token_to_sid.pop(disconnect_token, None) - async def emit_update(self, update: StateUpdate, sid: str) -> None: + async def emit_update( + self, update: StateUpdate, sid: str, room: str | None = None + ) -> None: """Emit an update to the client. Args: update: The state update to send. sid: The Socket.IO session id. + room: The room to send the update to. """ + # TODO We don't know when to leave a room yet. + for receiver in update.scopes: + if self.sid_to_token[sid] != receiver: + room = receiver + if room not in self.rooms(sid): + print(f"Entering room `{room}`") + await self.enter_room(sid, room) + + for room in self.rooms(sid): + if room not in update.scopes and room != sid: + print(f"Leaving room `{room}`") + await self.leave_room(sid, room) + + delta_by_scope = {} + events_by_scope = {} + + for event in update.events: + scope = self.token_to_sid.get(event.token, event.token) + events = events_by_scope.get(scope, []) + events.append(event) + events_by_scope[scope] = events + + for state, delta in update.delta.items(): + key = delta.get("_scope", sid) + d = delta_by_scope.get(key, {}) + d.update({state: delta}) + delta_by_scope[key] = d + + for scope, deltas in delta_by_scope.items(): + events = events_by_scope.get(scope, []) + print(f"Sending update to {scope} {events}") + single_update = StateUpdate( + delta=deltas, scopes=[scope], events=events, final=update.final + ) + + await asyncio.create_task( + self.emit( + str(constants.SocketEvent.EVENT), single_update.json(), to=scope + ) + ) + + for key in events_by_scope: + if key not in delta_by_scope: + single_update = StateUpdate( + delta={}, + scopes=[key], + events=events_by_scope.get(key, []), + final=update.final, + ) + print(f"Sending event to {key}") + await asyncio.create_task( + self.emit( + str(constants.SocketEvent.EVENT), single_update.json(), to=key + ) + ) + + update.scopes = [] + # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( self.emit(str(constants.SocketEvent.EVENT), update, to=sid) @@ -1590,6 +1660,8 @@ class EventNamespace(AsyncNamespace): except (KeyError, IndexError): client_ip = environ.get("REMOTE_ADDR", "0.0.0.0") + print(f"Received event {event.name} {event.token} from {client_ip}") + # Process the events. async for update in process(self.app, event, sid, headers, client_ip): # Emit the update from processing the event. diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 2198b82c2..2ad329e9d 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -47,4 +47,4 @@ class HydrateMiddleware(Middleware): state._clean() # Return the state update. - return StateUpdate(delta=delta, events=[]) + return StateUpdate(delta=delta, events=[], scopes=state.scopes_and_subscopes()) diff --git a/reflex/state.py b/reflex/state.py index a31aae032..a4df05fce 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -8,6 +8,8 @@ import copy import dataclasses import functools import inspect +import os +import traceback import json import pickle import sys @@ -135,6 +137,23 @@ if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: # Only warn about each state class size once. _WARNED_ABOUT_STATE_SIZE: Set[str] = set() + +def print_stack(depth: int = 3): + """Print the current stacktrace to the console. + + Args: + depth: Depth of the stack-trace to print + """ + stack = traceback.extract_stack() + stack.reverse() + print("stacktrace") + for idx in range(1, depth + 1): + stack_info = stack[idx] + print( + f" {stack_info.name} {os.path.basename(stack_info.filename)}:{stack_info.lineno}" + ) + + # Errors caught during pickling of state HANDLED_PICKLE_ERRORS = ( pickle.PicklingError, @@ -481,11 +500,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) @classmethod - def __init_subclass__(cls, mixin: bool = False, **kwargs): + def __init_subclass__( + cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs + ): """Do some magic for the subclass initialization. Args: mixin: Whether the subclass is a mixin and should not be initialized. + scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value. **kwargs: The kwargs to pass to the pydantic init_subclass method. Raises: @@ -499,6 +521,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if mixin: return + # Set the scope of the state. + cls._scope = scope # Handle locally-defined states for pickling. if "" in cls.__qualname__: cls._handle_local_def() @@ -795,7 +819,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } # Any substate containing a ComputedVar with cache=False always needs to be recomputed - if cls._always_dirty_computed_vars: + if cls._always_dirty_computed_vars: # or cls._scope is not None: # Tell parent classes that this substate has always dirty computed vars state_name = cls.get_name() parent_state = cls.get_parent_state() @@ -1532,9 +1556,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The requested state is missing, fetch from redis. pass parent_state = await state_manager.get_state( - token=_substate_key( - self.router.session.client_token, parent_state_name - ), + token=_substate_key(self._get_token(), parent_state_name), top_level=False, get_substates=False, parent_state=parent_state, @@ -1577,8 +1599,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) + return await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), + token=_substate_key(self._get_token(), state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, @@ -1720,6 +1743,34 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) + def scopes_and_subscopes(self) -> list[str]: + """Recursively gathers all scopes of self and substates. + + Returns: + A unique list of the scopes/token + """ + result = [self._get_token()] + for substate in self.substates.values(): + subscopes = substate.scopes_and_subscopes() + for subscope in subscopes: + if subscope not in result: + result.append(subscope) + return result + + def _get_token(self, other: type[BaseState] | None = None) -> str: + token = self.router.session.client_token + cls = other or self.__class__ + if cls._scope is not None: + scope = None + if isinstance(cls._scope, str): + scope = f"static{cls._scope}" + else: + scope = f"shared{getattr(self, cls._scope._var_name)}" + + token = scope + + return token + def _as_state_update( self, handler: EventHandler, @@ -1741,7 +1792,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # get the delta from the root of the state tree state = self._get_root_state() - token = self.router.session.client_token + token = self._get_token() # Convert valid EventHandler and EventSpec into Event fixed_events = fix_events(self._check_valid(handler, events), token) @@ -1755,6 +1806,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): delta=delta, events=fixed_events, final=final if not handler.is_background else True, + scopes=state.scopes_and_subscopes(), ) except Exception as ex: state._clean() @@ -1970,6 +2022,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): .union(self._always_dirty_computed_vars) ) + if len(self.scopes_and_subscopes()) > 1 and "router" in delta_vars: + delta_vars.remove("router") + subdelta: Dict[str, Any] = { prop: self.get_value(prop) for prop in delta_vars @@ -1979,6 +2034,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if len(subdelta) > 0: delta[self.get_full_name()] = subdelta + if self.__class__._scope is not None: + subdelta["_scope"] = self._get_token() + # Recursively find the substate deltas. substates = self.substates for substate in self.dirty_substates.union(self._always_dirty_substates): @@ -2114,7 +2172,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } else: computed_vars = {} - variables = {**base_vars, **computed_vars} + variables = {"_scope": self._get_token(), **base_vars, **computed_vars} d = { self.get_full_name(): {k: variables[k] for k in sorted(variables)}, } @@ -2431,7 +2489,7 @@ class OnLoadInternalState(State): return [ *fix_events( load_events, - self.router.session.client_token, + self._get_token(), router_data=self.router_data, ), State.set_is_hydrated(True), # type: ignore @@ -2833,6 +2891,8 @@ class StateUpdate: # Whether this is the final state update for the event. final: bool = True + scopes: list[str] = [] + def json(self) -> str: """Convert the state update to a JSON string. @@ -2948,10 +3008,15 @@ class StateManagerMemory(StateManager): Returns: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] if token not in self.states: self.states[token] = self.state(_reflex_internal_init=True) + + # TODO This is a bit madness maybe. + token = self.states[token]._get_token() + + if token not in self.states: + self.states[token] = self.state(_reflex_internal_init=True) + return self.states[token] @override @@ -2975,8 +3040,6 @@ class StateManagerMemory(StateManager): Yields: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] if token not in self._states_locks: async with self._state_manager_lock: if token not in self._states_locks: @@ -3394,6 +3457,10 @@ class StateManagerRedis(StateManager): f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" ) + if parent_state is None: + parent_state = await self._get_parent_state(token) + if parent_state is not None: + token = f"{parent_state._get_token(state_cls)}_{state_path}" # The deserialized or newly created (sub)state instance. state = None @@ -3472,6 +3539,8 @@ class StateManagerRedis(StateManager): ) client_token, substate_name = _split_substate_key(token) + client_token = state._get_token() + # If the substate name on the token doesn't match the instance name, it cannot have a parent. if state.parent_state is not None and state.get_full_name() != substate_name: raise RuntimeError(