fix: only open one connection/sub for each token per worker
bonus: properly cleanup StateManager connections on disconnect
This commit is contained in:
parent
0bf8ffefee
commit
f3e393e621
@ -1479,7 +1479,7 @@ class EventNamespace(AsyncNamespace):
|
||||
self.sid_to_token = {}
|
||||
self.app = app
|
||||
|
||||
def on_connect(self, sid, environ):
|
||||
async def on_connect(self, sid, environ):
|
||||
"""Event for when the websocket is connected.
|
||||
|
||||
Args:
|
||||
@ -1488,7 +1488,7 @@ class EventNamespace(AsyncNamespace):
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_disconnect(self, sid):
|
||||
async def on_disconnect(self, sid):
|
||||
"""Event for when the websocket disconnects.
|
||||
|
||||
Args:
|
||||
@ -1497,6 +1497,7 @@ class EventNamespace(AsyncNamespace):
|
||||
disconnect_token = self.sid_to_token.pop(sid, None)
|
||||
if disconnect_token:
|
||||
self.token_to_sid.pop(disconnect_token, None)
|
||||
await self.app.state_manager.disconnect(sid)
|
||||
|
||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
||||
"""Emit an update to the client.
|
||||
|
@ -2830,6 +2830,14 @@ class StateManager(Base, ABC):
|
||||
"""
|
||||
yield self.state()
|
||||
|
||||
async def disconnect(self, token: str) -> None:
|
||||
"""Disconnect the client with the given token.
|
||||
|
||||
Args:
|
||||
token: The token to disconnect.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StateManagerMemory(StateManager):
|
||||
"""A state manager that stores states in memory."""
|
||||
@ -2899,6 +2907,20 @@ class StateManagerMemory(StateManager):
|
||||
yield state
|
||||
await self.set_state(token, state)
|
||||
|
||||
@override
|
||||
async def disconnect(self, token: str) -> None:
|
||||
"""Disconnect the client with the given token.
|
||||
|
||||
Args:
|
||||
token: The token to disconnect.
|
||||
"""
|
||||
if token in self.states:
|
||||
del self.states[token]
|
||||
if lock := self._states_locks.get(token):
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
del self._states_locks[token]
|
||||
|
||||
|
||||
def _default_token_expiration() -> int:
|
||||
"""Get the default token expiration time.
|
||||
@ -3187,6 +3209,9 @@ class StateManagerRedis(StateManager):
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
|
||||
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
|
||||
|
||||
async def _get_parent_state(
|
||||
self, token: str, state: BaseState | None = None
|
||||
) -> BaseState | None:
|
||||
@ -3462,7 +3487,9 @@ class StateManagerRedis(StateManager):
|
||||
# Some redis servers only allow out-of-band configuration, so ignore errors here.
|
||||
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
|
||||
raise
|
||||
async with self.redis.pubsub() as pubsub:
|
||||
if lock_key not in self._pubsub_locks:
|
||||
self._pubsub_locks[lock_key] = asyncio.Lock()
|
||||
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
|
||||
await pubsub.psubscribe(lock_key_channel)
|
||||
while not state_is_locked:
|
||||
# wait for the lock to be released
|
||||
@ -3479,6 +3506,19 @@ class StateManagerRedis(StateManager):
|
||||
break
|
||||
state_is_locked = await self._try_get_lock(lock_key, lock_id)
|
||||
|
||||
@override
|
||||
async def disconnect(self, token: str):
|
||||
"""Disconnect the token from the redis client.
|
||||
|
||||
Args:
|
||||
token: The token to disconnect.
|
||||
"""
|
||||
lock_key = self._lock_key(token)
|
||||
if lock := self._pubsub_locks.get(lock_key):
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
del self._pubsub_locks[lock_key]
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lock(self, token: str):
|
||||
"""Obtain a redis lock for a token.
|
||||
|
Loading…
Reference in New Issue
Block a user