Merge 7fe33c9bf1
into 709c6dedf2
This commit is contained in:
commit
0a394f0a54
@ -1580,7 +1580,7 @@ class EventNamespace(AsyncNamespace):
|
|||||||
self.sid_to_token = {}
|
self.sid_to_token = {}
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
def on_connect(self, sid, environ):
|
async def on_connect(self, sid, environ):
|
||||||
"""Event for when the websocket is connected.
|
"""Event for when the websocket is connected.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1593,7 +1593,7 @@ class EventNamespace(AsyncNamespace):
|
|||||||
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
|
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_disconnect(self, sid):
|
async def on_disconnect(self, sid):
|
||||||
"""Event for when the websocket disconnects.
|
"""Event for when the websocket disconnects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1602,6 +1602,7 @@ class EventNamespace(AsyncNamespace):
|
|||||||
disconnect_token = self.sid_to_token.pop(sid, None)
|
disconnect_token = self.sid_to_token.pop(sid, None)
|
||||||
if disconnect_token:
|
if disconnect_token:
|
||||||
self.token_to_sid.pop(disconnect_token, None)
|
self.token_to_sid.pop(disconnect_token, None)
|
||||||
|
await self.app.state_manager.disconnect(sid)
|
||||||
|
|
||||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
||||||
"""Emit an update to the client.
|
"""Emit an update to the client.
|
||||||
|
@ -2936,6 +2936,14 @@ class StateManager(Base, ABC):
|
|||||||
"""
|
"""
|
||||||
yield self.state()
|
yield self.state()
|
||||||
|
|
||||||
|
async def disconnect(self, token: str) -> None:
|
||||||
|
"""Disconnect the client with the given token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to disconnect.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StateManagerMemory(StateManager):
|
class StateManagerMemory(StateManager):
|
||||||
"""A state manager that stores states in memory."""
|
"""A state manager that stores states in memory."""
|
||||||
@ -3005,6 +3013,20 @@ class StateManagerMemory(StateManager):
|
|||||||
yield state
|
yield state
|
||||||
await self.set_state(token, state)
|
await self.set_state(token, state)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def disconnect(self, token: str) -> None:
|
||||||
|
"""Disconnect the client with the given token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to disconnect.
|
||||||
|
"""
|
||||||
|
if token in self.states:
|
||||||
|
del self.states[token]
|
||||||
|
if lock := self._states_locks.get(token):
|
||||||
|
if lock.locked():
|
||||||
|
lock.release()
|
||||||
|
del self._states_locks[token]
|
||||||
|
|
||||||
|
|
||||||
def _default_token_expiration() -> int:
|
def _default_token_expiration() -> int:
|
||||||
"""Get the default token expiration time.
|
"""Get the default token expiration time.
|
||||||
@ -3307,6 +3329,9 @@ class StateManagerRedis(StateManager):
|
|||||||
b"evicted",
|
b"evicted",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
|
||||||
|
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
|
||||||
|
|
||||||
async def _get_parent_state(
|
async def _get_parent_state(
|
||||||
self, token: str, state: BaseState | None = None
|
self, token: str, state: BaseState | None = None
|
||||||
) -> BaseState | None:
|
) -> BaseState | None:
|
||||||
@ -3641,7 +3666,9 @@ class StateManagerRedis(StateManager):
|
|||||||
# Some redis servers only allow out-of-band configuration, so ignore errors here.
|
# Some redis servers only allow out-of-band configuration, so ignore errors here.
|
||||||
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
|
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
|
||||||
raise
|
raise
|
||||||
async with self.redis.pubsub() as pubsub:
|
if lock_key not in self._pubsub_locks:
|
||||||
|
self._pubsub_locks[lock_key] = asyncio.Lock()
|
||||||
|
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
|
||||||
await pubsub.psubscribe(lock_key_channel)
|
await pubsub.psubscribe(lock_key_channel)
|
||||||
# wait for the lock to be released
|
# wait for the lock to be released
|
||||||
while True:
|
while True:
|
||||||
@ -3651,6 +3678,19 @@ class StateManagerRedis(StateManager):
|
|||||||
# wait for lock events
|
# wait for lock events
|
||||||
await self._get_pubsub_message(pubsub)
|
await self._get_pubsub_message(pubsub)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def disconnect(self, token: str):
|
||||||
|
"""Disconnect the token from the redis client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to disconnect.
|
||||||
|
"""
|
||||||
|
lock_key = self._lock_key(token)
|
||||||
|
if lock := self._pubsub_locks.get(lock_key):
|
||||||
|
if lock.locked():
|
||||||
|
lock.release()
|
||||||
|
del self._pubsub_locks[lock_key]
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _lock(self, token: str):
|
async def _lock(self, token: str):
|
||||||
"""Obtain a redis lock for a token.
|
"""Obtain a redis lock for a token.
|
||||||
|
@ -42,6 +42,11 @@ def BackgroundTask():
|
|||||||
yield State.increment() # type: ignore
|
yield State.increment() # type: ignore
|
||||||
await asyncio.sleep(0.005)
|
await asyncio.sleep(0.005)
|
||||||
|
|
||||||
|
@rx.event(background=True)
|
||||||
|
async def fast_yielding(self):
|
||||||
|
for _ in range(1000):
|
||||||
|
yield State.increment()
|
||||||
|
|
||||||
@rx.event
|
@rx.event
|
||||||
def increment(self):
|
def increment(self):
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
@ -169,6 +174,11 @@ def BackgroundTask():
|
|||||||
on_click=State.yield_in_async_with_self,
|
on_click=State.yield_in_async_with_self,
|
||||||
id="yield-in-async-with-self",
|
id="yield-in-async-with-self",
|
||||||
),
|
),
|
||||||
|
rx.button(
|
||||||
|
"Fast Yielding",
|
||||||
|
on_click=State.fast_yielding,
|
||||||
|
id="fast-yielding",
|
||||||
|
),
|
||||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -375,3 +385,28 @@ def test_yield_in_async_with_self(
|
|||||||
|
|
||||||
yield_in_async_with_self_button.click()
|
yield_in_async_with_self_button.click()
|
||||||
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
|
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fast_yielding(
|
||||||
|
background_task: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
token: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test that fast yielding works as expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
background_task: harness for BackgroundTask app.
|
||||||
|
driver: WebDriver instance.
|
||||||
|
token: The token for the connected client.
|
||||||
|
"""
|
||||||
|
assert background_task.app_instance is not None
|
||||||
|
|
||||||
|
# get a reference to all buttons
|
||||||
|
fast_yielding_button = driver.find_element(By.ID, "fast-yielding")
|
||||||
|
|
||||||
|
# get a reference to the counter
|
||||||
|
counter = driver.find_element(By.ID, "counter")
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
|
||||||
|
|
||||||
|
fast_yielding_button.click()
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "1000", timeout=50)
|
||||||
|
Loading…
Reference in New Issue
Block a user