diff --git a/reflex/state.py b/reflex/state.py index 7a7d7f43e..21acee67b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -39,6 +39,7 @@ from typing import ( get_type_hints, ) +from redis.asyncio.client import PubSub from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self @@ -135,7 +136,7 @@ HANDLED_PICKLE_ERRORS = ( def _no_chain_background_task( - state_cls: Type["BaseState"], name: str, fn: Callable + state_cls: Type[BaseState], name: str, fn: Callable ) -> Callable: """Protect against directly chaining a background task from another event handler. @@ -172,9 +173,10 @@ def _no_chain_background_task( raise TypeError(f"{fn} is marked as a background task, but is not async.") +@functools.lru_cache() def _substate_key( token: str, - state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str], + state_cls_or_name: Type[BaseState] | str | Sequence[str], ) -> str: """Get the substate key. @@ -185,9 +187,7 @@ def _substate_key( Returns: The substate key. """ - if isinstance(state_cls_or_name, BaseState) or ( - isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState) - ): + if isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState): state_cls_or_name = state_cls_or_name.get_full_name() elif isinstance(state_cls_or_name, (list, tuple)): state_cls_or_name = ".".join(state_cls_or_name) @@ -301,7 +301,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): ) -class BaseState(Base, ABC, extra=pydantic.Extra.allow): +class HashableModelMetaclass(type(Base)): + def __hash__(self): + return id(self) + # return hash(f"{self.__module__}.{self.__name__}") + # return hash(self.get_full_name()) + + +class BaseState( + Base, ABC, extra=pydantic.Extra.allow, metaclass=HashableModelMetaclass +): """The state of the app.""" # A map from the var name to the var. @@ -3066,17 +3075,17 @@ class StateManagerDisk(StateManager): state: The state object to populate. root_state: The root state object. """ - for substate in state.get_substates(): - substate_token = _substate_key(client_token, substate) + for substate_cls in state.get_substates(): + substate_token = _substate_key(client_token, substate_cls) - fresh_instance = await root_state.get_state(substate) + fresh_instance = await root_state.get_state(substate_cls) instance = await self.load_state(substate_token) if instance is not None: # Ensure all substates exist, even if they weren't serialized previously. instance.substates = fresh_instance.substates else: instance = fresh_instance - state.substates[substate.get_name()] = instance + state.substates[substate_cls.get_name()] = instance instance.parent_state = state await self.populate_substates(client_token, instance, root_state) @@ -3120,7 +3129,7 @@ class StateManagerDisk(StateManager): client_token: The client token. substate: The substate to set. """ - substate_token = _substate_key(client_token, substate) + substate_token = _substate_key(client_token, type(substate)) if substate._get_was_touched(): substate._was_touched = False # Reset the touched flag after serializing. @@ -3177,6 +3186,18 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +PUBSUB_CLIENTS: Dict[str, PubSub] = {} + + +async def cached_pubsub(redis: Redis, lock_key_channel: str) -> PubSub: + if lock_key_channel in PUBSUB_CLIENTS: + return PUBSUB_CLIENTS[lock_key_channel] + pubsub = redis.pubsub() + await pubsub.psubscribe(lock_key_channel) + PUBSUB_CLIENTS[lock_key_channel] = pubsub + return pubsub + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3392,7 +3413,7 @@ class StateManagerRedis(StateManager): tasks.append( asyncio.create_task( self.set_state( - token=_substate_key(client_token, substate), + token=_substate_key(client_token, type(substate)), state=substate, lock_id=lock_id, ) @@ -3403,7 +3424,7 @@ class StateManagerRedis(StateManager): pickle_state = state._serialize() if pickle_state: await self.redis.set( - _substate_key(client_token, state), + _substate_key(client_token, type(state)), pickle_state, ex=self.token_expiration, ) @@ -3485,8 +3506,8 @@ class StateManagerRedis(StateManager): raise if lock_key not in self._pubsub_locks: self._pubsub_locks[lock_key] = asyncio.Lock() - async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub: - await pubsub.psubscribe(lock_key_channel) + async with self._pubsub_locks[lock_key]: + pubsub = await cached_pubsub(self.redis, lock_key_channel) while not state_is_locked: # wait for the lock to be released while True: diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 70e2202a6..fa8b55d03 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -44,7 +44,7 @@ def BackgroundTask(): @rx.event(background=True) async def fast_yielding(self): - for _ in range(1000): + for _ in range(10000): yield State.increment() @rx.event @@ -409,4 +409,4 @@ def test_fast_yielding( assert background_task._poll_for(lambda: counter.text == "0", timeout=5) fast_yielding_button.click() - assert background_task._poll_for(lambda: counter.text == "1000", timeout=50) + assert background_task._poll_for(lambda: counter.text == "10000", timeout=50) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c8a52e6c0..2c18699ba 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1759,7 +1759,7 @@ def substate_token_redis(state_manager_redis, token): Returns: Token concatenated with the state_manager's state full_name. """ - return _substate_key(token, state_manager_redis.state) + return _substate_key(token, type(state_manager_redis.state)) @pytest.mark.asyncio @@ -1918,7 +1918,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # Get the state from the state manager directly and check that the value is updated gotten_state = await mock_app.state_manager.get_state( - _substate_key(grandchild_state.router.session.client_token, grandchild_state) + _substate_key( + grandchild_state.router.session.client_token, type(grandchild_state) + ) ) if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists