fix: only open one connection/sub for each token per worker

bonus: properly cleanup StateManager connections on disconnect
This commit is contained in:
Benedikt Bartscher 2024-11-23 20:37:39 +01:00
parent 03eda2e90e
commit 389f4c7196
No known key found for this signature in database
2 changed files with 44 additions and 3 deletions

View File

@ -1477,7 +1477,7 @@ class EventNamespace(AsyncNamespace):
super().__init__(namespace)
self.app = app
def on_connect(self, sid, environ):
async def on_connect(self, sid, environ):
"""Event for when the websocket is connected.
Args:
@ -1486,7 +1486,7 @@ class EventNamespace(AsyncNamespace):
"""
pass
def on_disconnect(self, sid):
async def on_disconnect(self, sid):
"""Event for when the websocket disconnects.
Args:
@ -1495,6 +1495,7 @@ class EventNamespace(AsyncNamespace):
disconnect_token = self.sid_to_token.pop(sid, None)
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
await self.app.state_manager.disconnect(sid)
async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.

View File

@ -2826,6 +2826,14 @@ class StateManager(Base, ABC):
"""
yield self.state()
async def disconnect(self, token: str) -> None:
"""Disconnect the client with the given token.
Args:
token: The token to disconnect.
"""
pass
class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
@ -2895,6 +2903,20 @@ class StateManagerMemory(StateManager):
yield state
await self.set_state(token, state)
@override
async def disconnect(self, token: str) -> None:
"""Disconnect the client with the given token.
Args:
token: The token to disconnect.
"""
if token in self.states:
del self.states[token]
if lock := self._states_locks.get(token):
if lock.locked():
lock.release()
del self._states_locks[token]
def _default_token_expiration() -> int:
"""Get the default token expiration time.
@ -3183,6 +3205,9 @@ class StateManagerRedis(StateManager):
b"evicted",
}
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
@ -3458,7 +3483,9 @@ class StateManagerRedis(StateManager):
# Some redis servers only allow out-of-band configuration, so ignore errors here.
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
raise
async with self.redis.pubsub() as pubsub:
if lock_key not in self._pubsub_locks:
self._pubsub_locks[lock_key] = asyncio.Lock()
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
await pubsub.psubscribe(lock_key_channel)
while not state_is_locked:
# wait for the lock to be released
@ -3475,6 +3502,19 @@ class StateManagerRedis(StateManager):
break
state_is_locked = await self._try_get_lock(lock_key, lock_id)
@override
async def disconnect(self, token: str):
"""Disconnect the token from the redis client.
Args:
token: The token to disconnect.
"""
lock_key = self._lock_key(token)
if lock := self._pubsub_locks.get(lock_key):
if lock.locked():
lock.release()
del self._pubsub_locks[lock_key]
@contextlib.asynccontextmanager
async def _lock(self, token: str):
"""Obtain a redis lock for a token.