Merge 198d02cb9b
into d75a708e6b
This commit is contained in:
commit
5dff4dc216
@ -1495,7 +1495,7 @@ class EventNamespace(AsyncNamespace):
|
||||
self.sid_to_token = {}
|
||||
self.app = app
|
||||
|
||||
def on_connect(self, sid, environ):
|
||||
async def on_connect(self, sid, environ):
|
||||
"""Event for when the websocket is connected.
|
||||
|
||||
Args:
|
||||
@ -1504,7 +1504,7 @@ class EventNamespace(AsyncNamespace):
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_disconnect(self, sid):
|
||||
async def on_disconnect(self, sid):
|
||||
"""Event for when the websocket disconnects.
|
||||
|
||||
Args:
|
||||
@ -1513,6 +1513,7 @@ class EventNamespace(AsyncNamespace):
|
||||
disconnect_token = self.sid_to_token.pop(sid, None)
|
||||
if disconnect_token:
|
||||
self.token_to_sid.pop(disconnect_token, None)
|
||||
await self.app.state_manager.disconnect(sid)
|
||||
|
||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
||||
"""Emit an update to the client.
|
||||
|
@ -39,6 +39,7 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from redis.asyncio.client import PubSub
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -137,7 +138,7 @@ HANDLED_PICKLE_ERRORS = (
|
||||
|
||||
|
||||
def _no_chain_background_task(
|
||||
state_cls: Type["BaseState"], name: str, fn: Callable
|
||||
state_cls: Type[BaseState], name: str, fn: Callable
|
||||
) -> Callable:
|
||||
"""Protect against directly chaining a background task from another event handler.
|
||||
|
||||
@ -174,9 +175,10 @@ def _no_chain_background_task(
|
||||
raise TypeError(f"{fn} is marked as a background task, but is not async.")
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _substate_key(
|
||||
token: str,
|
||||
state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str],
|
||||
state_cls_or_name: Type[BaseState] | str | Sequence[str],
|
||||
) -> str:
|
||||
"""Get the substate key.
|
||||
|
||||
@ -187,9 +189,7 @@ def _substate_key(
|
||||
Returns:
|
||||
The substate key.
|
||||
"""
|
||||
if isinstance(state_cls_or_name, BaseState) or (
|
||||
isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
|
||||
):
|
||||
if isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState):
|
||||
state_cls_or_name = state_cls_or_name.get_full_name()
|
||||
elif isinstance(state_cls_or_name, (list, tuple)):
|
||||
state_cls_or_name = ".".join(state_cls_or_name)
|
||||
@ -315,7 +315,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
||||
)
|
||||
|
||||
|
||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
class HashableModelMetaclass(type(Base)):
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
# return hash(f"{self.__module__}.{self.__name__}")
|
||||
# return hash(self.get_full_name())
|
||||
|
||||
|
||||
class BaseState(
|
||||
Base, ABC, extra=pydantic.Extra.allow, metaclass=HashableModelMetaclass
|
||||
):
|
||||
"""The state of the app."""
|
||||
|
||||
# A map from the var name to the var.
|
||||
@ -2859,6 +2868,14 @@ class StateManager(Base, ABC):
|
||||
"""
|
||||
yield self.state()
|
||||
|
||||
async def disconnect(self, token: str) -> None:
|
||||
"""Disconnect the client with the given token.
|
||||
|
||||
Args:
|
||||
token: The token to disconnect.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StateManagerMemory(StateManager):
|
||||
"""A state manager that stores states in memory."""
|
||||
@ -2928,6 +2945,20 @@ class StateManagerMemory(StateManager):
|
||||
yield state
|
||||
await self.set_state(token, state)
|
||||
|
||||
@override
|
||||
async def disconnect(self, token: str) -> None:
|
||||
"""Disconnect the client with the given token.
|
||||
|
||||
Args:
|
||||
token: The token to disconnect.
|
||||
"""
|
||||
if token in self.states:
|
||||
del self.states[token]
|
||||
if lock := self._states_locks.get(token):
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
del self._states_locks[token]
|
||||
|
||||
|
||||
def _default_token_expiration() -> int:
|
||||
"""Get the default token expiration time.
|
||||
@ -3077,17 +3108,17 @@ class StateManagerDisk(StateManager):
|
||||
state: The state object to populate.
|
||||
root_state: The root state object.
|
||||
"""
|
||||
for substate in state.get_substates():
|
||||
substate_token = _substate_key(client_token, substate)
|
||||
for substate_cls in state.get_substates():
|
||||
substate_token = _substate_key(client_token, substate_cls)
|
||||
|
||||
fresh_instance = await root_state.get_state(substate)
|
||||
fresh_instance = await root_state.get_state(substate_cls)
|
||||
instance = await self.load_state(substate_token)
|
||||
if instance is not None:
|
||||
# Ensure all substates exist, even if they weren't serialized previously.
|
||||
instance.substates = fresh_instance.substates
|
||||
else:
|
||||
instance = fresh_instance
|
||||
state.substates[substate.get_name()] = instance
|
||||
state.substates[substate_cls.get_name()] = instance
|
||||
instance.parent_state = state
|
||||
|
||||
await self.populate_substates(client_token, instance, root_state)
|
||||
@ -3131,7 +3162,7 @@ class StateManagerDisk(StateManager):
|
||||
client_token: The client token.
|
||||
substate: The substate to set.
|
||||
"""
|
||||
substate_token = _substate_key(client_token, substate)
|
||||
substate_token = _substate_key(client_token, type(substate))
|
||||
|
||||
if substate._get_was_touched():
|
||||
substate._was_touched = False # Reset the touched flag after serializing.
|
||||
@ -3188,6 +3219,18 @@ def _default_lock_expiration() -> int:
|
||||
return get_config().redis_lock_expiration
|
||||
|
||||
|
||||
PUBSUB_CLIENTS: Dict[str, PubSub] = {}
|
||||
|
||||
|
||||
async def cached_pubsub(redis: Redis, lock_key_channel: str) -> PubSub:
|
||||
if lock_key_channel in PUBSUB_CLIENTS:
|
||||
return PUBSUB_CLIENTS[lock_key_channel]
|
||||
pubsub = redis.pubsub()
|
||||
await pubsub.psubscribe(lock_key_channel)
|
||||
PUBSUB_CLIENTS[lock_key_channel] = pubsub
|
||||
return pubsub
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
@ -3216,6 +3259,9 @@ class StateManagerRedis(StateManager):
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
|
||||
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
|
||||
|
||||
async def _get_parent_state(
|
||||
self, token: str, state: BaseState | None = None
|
||||
) -> BaseState | None:
|
||||
@ -3400,7 +3446,7 @@ class StateManagerRedis(StateManager):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
self.set_state(
|
||||
token=_substate_key(client_token, substate),
|
||||
token=_substate_key(client_token, type(substate)),
|
||||
state=substate,
|
||||
lock_id=lock_id,
|
||||
)
|
||||
@ -3411,7 +3457,7 @@ class StateManagerRedis(StateManager):
|
||||
pickle_state = state._serialize()
|
||||
if pickle_state:
|
||||
await self.redis.set(
|
||||
_substate_key(client_token, state),
|
||||
_substate_key(client_token, type(state)),
|
||||
pickle_state,
|
||||
ex=self.token_expiration,
|
||||
)
|
||||
@ -3491,8 +3537,10 @@ class StateManagerRedis(StateManager):
|
||||
# Some redis servers only allow out-of-band configuration, so ignore errors here.
|
||||
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
|
||||
raise
|
||||
async with self.redis.pubsub() as pubsub:
|
||||
await pubsub.psubscribe(lock_key_channel)
|
||||
if lock_key not in self._pubsub_locks:
|
||||
self._pubsub_locks[lock_key] = asyncio.Lock()
|
||||
async with self._pubsub_locks[lock_key]:
|
||||
pubsub = await cached_pubsub(self.redis, lock_key_channel)
|
||||
while not state_is_locked:
|
||||
# wait for the lock to be released
|
||||
while True:
|
||||
@ -3508,6 +3556,19 @@ class StateManagerRedis(StateManager):
|
||||
break
|
||||
state_is_locked = await self._try_get_lock(lock_key, lock_id)
|
||||
|
||||
@override
|
||||
async def disconnect(self, token: str):
|
||||
"""Disconnect the token from the redis client.
|
||||
|
||||
Args:
|
||||
token: The token to disconnect.
|
||||
"""
|
||||
lock_key = self._lock_key(token)
|
||||
if lock := self._pubsub_locks.get(lock_key):
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
del self._pubsub_locks[lock_key]
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lock(self, token: str):
|
||||
"""Obtain a redis lock for a token.
|
||||
|
@ -42,6 +42,11 @@ def BackgroundTask():
|
||||
yield State.increment() # type: ignore
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
@rx.event(background=True)
|
||||
async def fast_yielding(self):
|
||||
for _ in range(10000):
|
||||
yield State.increment()
|
||||
|
||||
@rx.event
|
||||
def increment(self):
|
||||
self.counter += 1
|
||||
@ -169,6 +174,11 @@ def BackgroundTask():
|
||||
on_click=State.yield_in_async_with_self,
|
||||
id="yield-in-async-with-self",
|
||||
),
|
||||
rx.button(
|
||||
"Fast Yielding",
|
||||
on_click=State.fast_yielding,
|
||||
id="fast-yielding",
|
||||
),
|
||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||
)
|
||||
|
||||
@ -375,3 +385,28 @@ def test_yield_in_async_with_self(
|
||||
|
||||
yield_in_async_with_self_button.click()
|
||||
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
|
||||
|
||||
|
||||
def test_fast_yielding(
|
||||
background_task: AppHarness,
|
||||
driver: WebDriver,
|
||||
token: str,
|
||||
) -> None:
|
||||
"""Test that fast yielding works as expected.
|
||||
|
||||
Args:
|
||||
background_task: harness for BackgroundTask app.
|
||||
driver: WebDriver instance.
|
||||
token: The token for the connected client.
|
||||
"""
|
||||
assert background_task.app_instance is not None
|
||||
|
||||
# get a reference to all buttons
|
||||
fast_yielding_button = driver.find_element(By.ID, "fast-yielding")
|
||||
|
||||
# get a reference to the counter
|
||||
counter = driver.find_element(By.ID, "counter")
|
||||
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
|
||||
|
||||
fast_yielding_button.click()
|
||||
assert background_task._poll_for(lambda: counter.text == "10000", timeout=50)
|
||||
|
@ -1773,7 +1773,7 @@ def substate_token_redis(state_manager_redis, token):
|
||||
Returns:
|
||||
Token concatenated with the state_manager's state full_name.
|
||||
"""
|
||||
return _substate_key(token, state_manager_redis.state)
|
||||
return _substate_key(token, type(state_manager_redis.state))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1945,7 +1945,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
|
||||
# Get the state from the state manager directly and check that the value is updated
|
||||
gotten_state = await mock_app.state_manager.get_state(
|
||||
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
|
||||
_substate_key(
|
||||
grandchild_state.router.session.client_token, type(grandchild_state)
|
||||
)
|
||||
)
|
||||
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
# For in-process store, only one instance of the state exists
|
||||
|
Loading…
Reference in New Issue
Block a user