diff --git a/reflex/app.py b/reflex/app.py index e188051b3..7e1f9da1c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1481,12 +1481,38 @@ class EventNamespace(AsyncNamespace): if self.sid_to_token[sid] != receiver: room = receiver if room not in self.rooms(sid): + print(f"Entering room `{room}`") await self.enter_room(sid, room) + # for room in self.rooms(sid): + # if room not in update.scopes and room != sid: + # print(f"Leaving room `{room}`") + # await self.leave_room(sid, room) + + # deltas = {delta._scope: {state: delta} for state, delta in update.delta.values.items()} + + delta_by_scope = {} + + for state, delta in update.delta.items(): + key = delta.get("_scope", sid) + d = delta_by_scope.get(key, {}) + d.update({state: delta}) + delta_by_scope[key] = d + + for scope, deltas in delta_by_scope.items(): + single_update = StateUpdate( + delta=deltas, scopes=[scope], events=update.events, final=update.final + ) + + await asyncio.create_task( + self.emit( + str(constants.SocketEvent.EVENT), single_update.json(), to=scope + ) + ) + + update.scopes = [] + # Creating a task prevents the update from being blocked behind other coroutines. - await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid) - ) async def on_event(self, sid, data): """Event for receiving front-end websocket events. diff --git a/reflex/state.py b/reflex/state.py index beeeb7069..4f7a85e8d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -702,7 +702,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) # Any substate containing a ComputedVar with cache=False always needs to be recomputed - if cls._always_dirty_computed_vars: + if cls._always_dirty_computed_vars: # or cls._scope is not None: # Tell parent classes that this substate has always dirty computed vars state_name = cls.get_name() parent_state = cls.get_parent_state() @@ -1461,7 +1461,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): result.append(subscope) return result - def _get_token(self, other=None) -> str: + def _get_token(self, other: type[BaseState] | None = None) -> str: token = self.router.session.client_token cls = other or self.__class__ if cls._scope is not None: @@ -1701,6 +1701,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if len(subdelta) > 0: delta[self.get_full_name()] = subdelta + if self.__class__._scope is not None: + subdelta["_scope"] = self._get_token() + # Recursively find the substate deltas. substates = self.substates for substate in self.dirty_substates.union(self._always_dirty_substates): @@ -1836,7 +1839,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } else: computed_vars = {} - variables = {**base_vars, **computed_vars} + variables = {"_scope": self._get_token(), **base_vars, **computed_vars} d = { self.get_full_name(): {k: variables[k] for k in sorted(variables)}, }