diff --git a/reflex/app.py b/reflex/app.py index d432925ab..a208d35ff 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1580,7 +1580,7 @@ class EventNamespace(AsyncNamespace): self.sid_to_token = {} self.app = app - def on_connect(self, sid, environ): + async def on_connect(self, sid, environ): """Event for when the websocket is connected. Args: @@ -1593,7 +1593,7 @@ class EventNamespace(AsyncNamespace): f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}." ) - def on_disconnect(self, sid): + async def on_disconnect(self, sid): """Event for when the websocket disconnects. Args: @@ -1602,6 +1602,7 @@ class EventNamespace(AsyncNamespace): disconnect_token = self.sid_to_token.pop(sid, None) if disconnect_token: self.token_to_sid.pop(disconnect_token, None) + await self.app.state_manager.disconnect(sid) async def emit_update(self, update: StateUpdate, sid: str) -> None: """Emit an update to the client. diff --git a/reflex/state.py b/reflex/state.py index 40afcbc79..5f36ed455 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2936,6 +2936,14 @@ class StateManager(Base, ABC): """ yield self.state() + async def disconnect(self, token: str) -> None: + """Disconnect the client with the given token. + + Args: + token: The token to disconnect. + """ + pass + class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" @@ -3005,6 +3013,20 @@ class StateManagerMemory(StateManager): yield state await self.set_state(token, state) + @override + async def disconnect(self, token: str) -> None: + """Disconnect the client with the given token. + + Args: + token: The token to disconnect. + """ + if token in self.states: + del self.states[token] + if lock := self._states_locks.get(token): + if lock.locked(): + lock.release() + del self._states_locks[token] + def _default_token_expiration() -> int: """Get the default token expiration time. @@ -3307,6 +3329,9 @@ class StateManagerRedis(StateManager): b"evicted", } + # This lock is used to ensure we only subscribe to keyspace events once per token and worker + _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({}) + async def _get_parent_state( self, token: str, state: BaseState | None = None ) -> BaseState | None: @@ -3641,7 +3666,9 @@ class StateManagerRedis(StateManager): # Some redis servers only allow out-of-band configuration, so ignore errors here. if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get(): raise - async with self.redis.pubsub() as pubsub: + if lock_key not in self._pubsub_locks: + self._pubsub_locks[lock_key] = asyncio.Lock() + async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub: await pubsub.psubscribe(lock_key_channel) # wait for the lock to be released while True: @@ -3651,6 +3678,19 @@ class StateManagerRedis(StateManager): # wait for lock events await self._get_pubsub_message(pubsub) + @override + async def disconnect(self, token: str): + """Disconnect the token from the redis client. + + Args: + token: The token to disconnect. + """ + lock_key = self._lock_key(token) + if lock := self._pubsub_locks.get(lock_key): + if lock.locked(): + lock.release() + del self._pubsub_locks[lock_key] + @contextlib.asynccontextmanager async def _lock(self, token: str): """Obtain a redis lock for a token. diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index cb8fda019..672254f44 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -42,6 +42,11 @@ def BackgroundTask(): yield State.increment() # type: ignore await asyncio.sleep(0.005) + @rx.event(background=True) + async def fast_yielding(self): + for _ in range(1000): + yield State.increment() + @rx.event def increment(self): self.counter += 1 @@ -169,6 +174,11 @@ def BackgroundTask(): on_click=State.yield_in_async_with_self, id="yield-in-async-with-self", ), + rx.button( + "Fast Yielding", + on_click=State.fast_yielding, + id="fast-yielding", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -375,3 +385,28 @@ def test_yield_in_async_with_self( yield_in_async_with_self_button.click() assert background_task._poll_for(lambda: counter.text == "2", timeout=5) + + +def test_fast_yielding( + background_task: AppHarness, + driver: WebDriver, + token: str, +) -> None: + """Test that fast yielding works as expected. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + fast_yielding_button = driver.find_element(By.ID, "fast-yielding") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + fast_yielding_button.click() + assert background_task._poll_for(lambda: counter.text == "1000", timeout=50)