From 3bcdc12c26628d6990c5b7052f3a827fef68fa47 Mon Sep 17 00:00:00 2001 From: Andreas Eismann Date: Tue, 16 Jul 2024 22:51:11 +0200 Subject: [PATCH 1/6] feat: PoC for shared states. --- reflex/app.py | 7 +++++-- reflex/state.py | 43 ++++++++++++++++++++++++++++++++----------- reflex/vars.py | 6 ++++++ 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 658ba1a1f..559861c24 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1460,16 +1460,19 @@ class EventNamespace(AsyncNamespace): if disconnect_token: self.token_to_sid.pop(disconnect_token, None) - async def emit_update(self, update: StateUpdate, sid: str) -> None: + async def emit_update( + self, update: StateUpdate, sid: str, room: str | None = None + ) -> None: """Emit an update to the client. Args: update: The state update to send. sid: The Socket.IO session id. + room: The room to send the update to. """ # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid) ) async def on_event(self, sid, data): diff --git a/reflex/state.py b/reflex/state.py index 49b5bd4a4..d2c5fa567 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -442,11 +442,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) @classmethod - def __init_subclass__(cls, mixin: bool = False, **kwargs): + def __init_subclass__( + cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs + ): """Do some magic for the subclass initialization. Args: mixin: Whether the subclass is a mixin and should not be initialized. + scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value. **kwargs: The kwargs to pass to the pydantic init_subclass method. Raises: @@ -460,6 +463,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if mixin: return + # Set the scope of the state. + cls._scope = scope + # Validate the module name. cls._validate_module_name() @@ -1270,9 +1276,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The requested state is missing, fetch from redis. pass parent_state = await state_manager.get_state( - token=_substate_key( - self.router.session.client_token, parent_state_name - ), + token=_substate_key(self._get_token(), parent_state_name), top_level=False, get_substates=False, parent_state=parent_state, @@ -1318,8 +1322,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) + return await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), + token=_substate_key(self._get_token(), state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, @@ -1425,6 +1430,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) + def _get_token(self) -> str: + token = self.router.session.client_token + if self.__class__._scope is not None: + scope = None + if isinstance(self.__class__._scope, str): + scope = self.__class__._scope + else: + scope = getattr(self, self.__class__._scope._var_name) + + token = scope + + return token + def _as_state_update( self, handler: EventHandler, @@ -1448,7 +1466,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): while state.parent_state is not None: state = state.parent_state - token = self.router.session.client_token + token = self._get_token() # Convert valid EventHandler and EventSpec into Event fixed_events = fix_events(self._check_valid(handler, events), token) @@ -1909,7 +1927,7 @@ class OnLoadInternalState(State): return [ *fix_events( load_events, - self.router.session.client_token, + self._get_token(), router_data=self.router_data, ), State.set_is_hydrated(True), # type: ignore @@ -2328,10 +2346,15 @@ class StateManagerMemory(StateManager): Returns: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] if token not in self.states: self.states[token] = self.state(_reflex_internal_init=True) + + # TODO This is a bit madness maybe. + token = self.states[token]._get_token() + + if token not in self.states: + self.states[token] = self.state(_reflex_internal_init=True) + return self.states[token] async def set_state(self, token: str, state: BaseState): @@ -2353,8 +2376,6 @@ class StateManagerMemory(StateManager): Yields: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] if token not in self._states_locks: async with self._state_manager_lock: if token not in self._states_locks: diff --git a/reflex/vars.py b/reflex/vars.py index 8d93f99c0..79b992cc4 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -383,6 +383,8 @@ class Var: # Extra metadata associated with the Var _var_data: Optional[VarData] + _var_is_scope: bool = False + @classmethod def create( cls, @@ -390,6 +392,7 @@ class Var: _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: Optional[VarData] = None, + _var_is_scope: bool = False, ) -> Var | None: """Create a var from a value. @@ -455,6 +458,7 @@ class Var: _var_is_local=_var_is_local, _var_is_string=_var_is_string if _var_is_string is not None else False, _var_data=_var_data, + _var_is_scope=_var_is_scope, ) @classmethod @@ -1866,6 +1870,8 @@ class BaseVar(Var): # Extra metadata associated with the Var _var_data: Optional[VarData] = dataclasses.field(default=None) + _var_is_scope: bool = dataclasses.field(default=False) + def __hash__(self) -> int: """Define a hash function for a var. From ab19508451214f5ffd38a1876829df5617076f16 Mon Sep 17 00:00:00 2001 From: Andreas Eismann Date: Wed, 17 Jul 2024 10:21:53 +0200 Subject: [PATCH 2/6] feat: Added redis token redirection. --- reflex/state.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index d2c5fa567..cc9063792 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1430,14 +1430,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) - def _get_token(self) -> str: + def _get_token(self, other=None) -> str: token = self.router.session.client_token - if self.__class__._scope is not None: + cls = other or self.__class__ + if cls._scope is not None: scope = None - if isinstance(self.__class__._scope, str): - scope = self.__class__._scope + if isinstance(cls._scope, str): + scope = cls._scope else: - scope = getattr(self, self.__class__._scope._var_name) + scope = getattr(self, cls._scope._var_name) token = scope @@ -2561,6 +2562,10 @@ class StateManagerRedis(StateManager): "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" ) + if parent_state is None: + parent_state = await self._get_parent_state(token) + if parent_state is not None: + token = f"{parent_state._get_token(state_cls)}_{state_path}" # Fetch the serialized substate from redis. redis_state = await self.redis.get(token) @@ -2657,6 +2662,8 @@ class StateManagerRedis(StateManager): "or use `@rx.background` decorator for long-running tasks." ) client_token, substate_name = _split_substate_key(token) + client_token = state._get_token() + # If the substate name on the token doesn't match the instance name, it cannot have a parent. if state.parent_state is not None and state.get_full_name() != substate_name: raise RuntimeError( From db4c73a027e29260365228331a73f91c3da2ea78 Mon Sep 17 00:00:00 2001 From: Andreas Eismann Date: Wed, 17 Jul 2024 11:22:00 +0200 Subject: [PATCH 3/6] feat: Adding subscriptions to rooms. TODO: Remove subscriptions when scopes change. --- reflex/app.py | 13 +++++++-- reflex/middleware/hydrate_middleware.py | 2 +- reflex/state.py | 38 +++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 559861c24..7a6615ed1 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1078,7 +1078,9 @@ class App(MiddlewareMixin, LifespanMixin, Base): # When the state is modified reset dirty status and emit the delta to the frontend. state._clean() await self.event_namespace.emit_update( - update=StateUpdate(delta=delta), + update=StateUpdate( + delta=delta, scopes=state.scopes_and_subscopes() + ), sid=state.router.session.session_id, ) @@ -1273,7 +1275,7 @@ async def process( else: if app._process_background(state, event) is not None: # `final=True` allows the frontend send more events immediately. - yield StateUpdate(final=True) + yield StateUpdate(final=True, scopes=state.scopes_and_subscopes()) return # Process the event synchronously. @@ -1470,6 +1472,13 @@ class EventNamespace(AsyncNamespace): sid: The Socket.IO session id. room: The room to send the update to. """ + # TODO We don't know when to leave a room yet. + for receiver in update.scopes: + if self.sid_to_token[sid] != receiver: + room = receiver + if room not in self.rooms(sid): + await self.enter_room(sid, room) + # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid) diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index b5694e22f..99abf97aa 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -46,4 +46,4 @@ class HydrateMiddleware(Middleware): state._clean() # Return the state update. - return StateUpdate(delta=delta, events=[]) + return StateUpdate(delta=delta, events=[], scopes=state.scopes_and_subscopes()) diff --git a/reflex/state.py b/reflex/state.py index cc9063792..fab22a8e1 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -8,6 +8,7 @@ import copy import functools import inspect import os +import traceback import uuid from abc import ABC, abstractmethod from collections import defaultdict @@ -68,6 +69,22 @@ var = computed_var TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb +def print_stack(depth: int = 3): + """Print the current stacktrace to the console. + + Args: + depth: Depth of the stack-trace to print + """ + stack = traceback.extract_stack() + stack.reverse() + print("stacktrace") + for idx in range(1, depth + 1): + stack_info = stack[idx] + print( + f" {stack_info.name} {os.path.basename(stack_info.filename)}:{stack_info.lineno}" + ) + + class HeaderData(Base): """An object containing headers data.""" @@ -1430,6 +1447,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) + def scopes_and_subscopes(self) -> list[str]: + """Recursively gathers all scopes of self and substates. + + Returns: + A unique list of the scopes/token + """ + result = [self._get_token()] + for substate in self.substates.values(): + subscopes = substate.scopes_and_subscopes() + for subscope in subscopes: + if subscope not in result: + result.append(subscope) + return result + def _get_token(self, other=None) -> str: token = self.router.session.client_token cls = other or self.__class__ @@ -1481,6 +1512,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): delta=delta, events=fixed_events, final=final if not handler.is_background else True, + scopes=state.scopes_and_subscopes(), ) except Exception as ex: state._clean() @@ -2254,6 +2286,8 @@ class StateUpdate(Base): # Whether this is the final state update for the event. final: bool = True + scopes: list[str] = [] + class StateManager(Base, ABC): """A class to manage many client states.""" @@ -2531,6 +2565,10 @@ class StateManagerRedis(StateManager): for substate_name, substate_task in tasks.items(): state.substates[substate_name] = await substate_task + # async def print_all_keys(self): + # for key in await self.redis.keys(): + # print(f"redis_key: {key}") + async def get_state( self, token: str, From 351e5a0250df724f7147d0ef9caff7e4136acdb6 Mon Sep 17 00:00:00 2001 From: Andreas Eismann Date: Wed, 17 Jul 2024 14:43:27 +0200 Subject: [PATCH 4/6] fix: Removing router_data from StateUpdate. --- reflex/app.py | 26 +++++++++++++++----------- reflex/state.py | 7 +++++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 7a6615ed1..e188051b3 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1244,17 +1244,21 @@ async def process( from reflex.utils import telemetry try: - # Add request data to the state. - router_data = event.router_data - router_data.update( - { - constants.RouteVar.QUERY: format.format_query_params(event.router_data), - constants.RouteVar.CLIENT_TOKEN: event.token, - constants.RouteVar.SESSION_ID: sid, - constants.RouteVar.HEADERS: headers, - constants.RouteVar.CLIENT_IP: client_ip, - } - ) + router_data = {} + if event.router_data: + # Add request data to the state. + router_data = event.router_data + router_data.update( + { + constants.RouteVar.QUERY: format.format_query_params( + event.router_data + ), + constants.RouteVar.CLIENT_TOKEN: event.token, + constants.RouteVar.SESSION_ID: sid, + constants.RouteVar.HEADERS: headers, + constants.RouteVar.CLIENT_IP: client_ip, + } + ) # Get the state for the session exclusively. async with app.state_manager.modify_state(event.substate_token) as state: # re-assign only when the value is different diff --git a/reflex/state.py b/reflex/state.py index fab22a8e1..beeeb7069 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1467,9 +1467,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if cls._scope is not None: scope = None if isinstance(cls._scope, str): - scope = cls._scope + scope = f"static{cls._scope}" else: - scope = getattr(self, cls._scope._var_name) + scope = f"shared{getattr(self, cls._scope._var_name)}" token = scope @@ -1690,6 +1690,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): .union(self._always_dirty_computed_vars) ) + if len(self.scopes_and_subscopes()) > 1 and "router" in delta_vars: + delta_vars.remove("router") + subdelta = { prop: getattr(self, prop) for prop in delta_vars From 50b9f7b1da9700c0b81f76ef9b727c706e35a6cd Mon Sep 17 00:00:00 2001 From: Andreas Eismann Date: Wed, 17 Jul 2024 17:33:45 +0200 Subject: [PATCH 5/6] fix: Separating StateUpdates by receivee. --- reflex/app.py | 32 +++++++++++++++++++++++++++++--- reflex/state.py | 9 ++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index e188051b3..7e1f9da1c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1481,12 +1481,38 @@ class EventNamespace(AsyncNamespace): if self.sid_to_token[sid] != receiver: room = receiver if room not in self.rooms(sid): + print(f"Entering room `{room}`") await self.enter_room(sid, room) + # for room in self.rooms(sid): + # if room not in update.scopes and room != sid: + # print(f"Leaving room `{room}`") + # await self.leave_room(sid, room) + + # deltas = {delta._scope: {state: delta} for state, delta in update.delta.values.items()} + + delta_by_scope = {} + + for state, delta in update.delta.items(): + key = delta.get("_scope", sid) + d = delta_by_scope.get(key, {}) + d.update({state: delta}) + delta_by_scope[key] = d + + for scope, deltas in delta_by_scope.items(): + single_update = StateUpdate( + delta=deltas, scopes=[scope], events=update.events, final=update.final + ) + + await asyncio.create_task( + self.emit( + str(constants.SocketEvent.EVENT), single_update.json(), to=scope + ) + ) + + update.scopes = [] + # Creating a task prevents the update from being blocked behind other coroutines. - await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid) - ) async def on_event(self, sid, data): """Event for receiving front-end websocket events. diff --git a/reflex/state.py b/reflex/state.py index beeeb7069..4f7a85e8d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -702,7 +702,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) # Any substate containing a ComputedVar with cache=False always needs to be recomputed - if cls._always_dirty_computed_vars: + if cls._always_dirty_computed_vars: # or cls._scope is not None: # Tell parent classes that this substate has always dirty computed vars state_name = cls.get_name() parent_state = cls.get_parent_state() @@ -1461,7 +1461,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): result.append(subscope) return result - def _get_token(self, other=None) -> str: + def _get_token(self, other: type[BaseState] | None = None) -> str: token = self.router.session.client_token cls = other or self.__class__ if cls._scope is not None: @@ -1701,6 +1701,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if len(subdelta) > 0: delta[self.get_full_name()] = subdelta + if self.__class__._scope is not None: + subdelta["_scope"] = self._get_token() + # Recursively find the substate deltas. substates = self.substates for substate in self.dirty_substates.union(self._always_dirty_substates): @@ -1836,7 +1839,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } else: computed_vars = {} - variables = {**base_vars, **computed_vars} + variables = {"_scope": self._get_token(), **base_vars, **computed_vars} d = { self.get_full_name(): {k: variables[k] for k in sorted(variables)}, } From 608dcab2265477589cce0b6d149f1bece05dd2b0 Mon Sep 17 00:00:00 2001 From: Andreas Eismann Date: Wed, 17 Jul 2024 22:40:55 +0200 Subject: [PATCH 6/6] fix: Separate stateupdates and events. --- reflex/app.py | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 7e1f9da1c..001c95885 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1259,6 +1259,9 @@ async def process( constants.RouteVar.CLIENT_IP: client_ip, } ) + print( + f"Processing event: {event.name} with payload: {event.payload} {event.substate_token}" + ) # Get the state for the session exclusively. async with app.state_manager.modify_state(event.substate_token) as state: # re-assign only when the value is different @@ -1484,14 +1487,19 @@ class EventNamespace(AsyncNamespace): print(f"Entering room `{room}`") await self.enter_room(sid, room) - # for room in self.rooms(sid): - # if room not in update.scopes and room != sid: - # print(f"Leaving room `{room}`") - # await self.leave_room(sid, room) - - # deltas = {delta._scope: {state: delta} for state, delta in update.delta.values.items()} + for room in self.rooms(sid): + if room not in update.scopes and room != sid: + print(f"Leaving room `{room}`") + await self.leave_room(sid, room) delta_by_scope = {} + events_by_scope = {} + + for event in update.events: + scope = self.token_to_sid.get(event.token, event.token) + events = events_by_scope.get(scope, []) + events.append(event) + events_by_scope[scope] = events for state, delta in update.delta.items(): key = delta.get("_scope", sid) @@ -1500,8 +1508,10 @@ class EventNamespace(AsyncNamespace): delta_by_scope[key] = d for scope, deltas in delta_by_scope.items(): + events = events_by_scope.get(scope, []) + print(f"Sending update to {scope} {events}") single_update = StateUpdate( - delta=deltas, scopes=[scope], events=update.events, final=update.final + delta=deltas, scopes=[scope], events=events, final=update.final ) await asyncio.create_task( @@ -1510,6 +1520,21 @@ class EventNamespace(AsyncNamespace): ) ) + for key in events_by_scope: + if key not in delta_by_scope: + single_update = StateUpdate( + delta={}, + scopes=[key], + events=events_by_scope.get(key, []), + final=update.final, + ) + print(f"Sending event to {key}") + await asyncio.create_task( + self.emit( + str(constants.SocketEvent.EVENT), single_update.json(), to=key + ) + ) + update.scopes = [] # Creating a task prevents the update from being blocked behind other coroutines. @@ -1544,6 +1569,8 @@ class EventNamespace(AsyncNamespace): except (KeyError, IndexError): client_ip = environ.get("REMOTE_ADDR", "0.0.0.0") + print(f"Received event {event.name} {event.token} from {client_ip}") + # Process the events. async for update in process(self.app, event, sid, headers, client_ip): # Emit the update from processing the event.