wip shared pubsub and cached _substate_key, garbage collect pubsubs tbd

This commit is contained in:
Benedikt Bartscher 2024-11-25 13:54:53 +01:00
parent 7fe33c9bf1
commit 198d02cb9b
No known key found for this signature in database
3 changed files with 42 additions and 19 deletions

View File

@ -39,6 +39,7 @@ from typing import (
get_type_hints, get_type_hints,
) )
from redis.asyncio.client import PubSub
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self from typing_extensions import Self
@ -135,7 +136,7 @@ HANDLED_PICKLE_ERRORS = (
def _no_chain_background_task( def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable state_cls: Type[BaseState], name: str, fn: Callable
) -> Callable: ) -> Callable:
"""Protect against directly chaining a background task from another event handler. """Protect against directly chaining a background task from another event handler.
@ -172,9 +173,10 @@ def _no_chain_background_task(
raise TypeError(f"{fn} is marked as a background task, but is not async.") raise TypeError(f"{fn} is marked as a background task, but is not async.")
@functools.lru_cache()
def _substate_key( def _substate_key(
token: str, token: str,
state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str], state_cls_or_name: Type[BaseState] | str | Sequence[str],
) -> str: ) -> str:
"""Get the substate key. """Get the substate key.
@ -185,9 +187,7 @@ def _substate_key(
Returns: Returns:
The substate key. The substate key.
""" """
if isinstance(state_cls_or_name, BaseState) or ( if isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState):
isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
):
state_cls_or_name = state_cls_or_name.get_full_name() state_cls_or_name = state_cls_or_name.get_full_name()
elif isinstance(state_cls_or_name, (list, tuple)): elif isinstance(state_cls_or_name, (list, tuple)):
state_cls_or_name = ".".join(state_cls_or_name) state_cls_or_name = ".".join(state_cls_or_name)
@ -301,7 +301,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
) )
class BaseState(Base, ABC, extra=pydantic.Extra.allow): class HashableModelMetaclass(type(Base)):
def __hash__(self):
return id(self)
# return hash(f"{self.__module__}.{self.__name__}")
# return hash(self.get_full_name())
class BaseState(
Base, ABC, extra=pydantic.Extra.allow, metaclass=HashableModelMetaclass
):
"""The state of the app.""" """The state of the app."""
# A map from the var name to the var. # A map from the var name to the var.
@ -3066,17 +3075,17 @@ class StateManagerDisk(StateManager):
state: The state object to populate. state: The state object to populate.
root_state: The root state object. root_state: The root state object.
""" """
for substate in state.get_substates(): for substate_cls in state.get_substates():
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, substate_cls)
fresh_instance = await root_state.get_state(substate) fresh_instance = await root_state.get_state(substate_cls)
instance = await self.load_state(substate_token) instance = await self.load_state(substate_token)
if instance is not None: if instance is not None:
# Ensure all substates exist, even if they weren't serialized previously. # Ensure all substates exist, even if they weren't serialized previously.
instance.substates = fresh_instance.substates instance.substates = fresh_instance.substates
else: else:
instance = fresh_instance instance = fresh_instance
state.substates[substate.get_name()] = instance state.substates[substate_cls.get_name()] = instance
instance.parent_state = state instance.parent_state = state
await self.populate_substates(client_token, instance, root_state) await self.populate_substates(client_token, instance, root_state)
@ -3120,7 +3129,7 @@ class StateManagerDisk(StateManager):
client_token: The client token. client_token: The client token.
substate: The substate to set. substate: The substate to set.
""" """
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, type(substate))
if substate._get_was_touched(): if substate._get_was_touched():
substate._was_touched = False # Reset the touched flag after serializing. substate._was_touched = False # Reset the touched flag after serializing.
@ -3177,6 +3186,18 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration return get_config().redis_lock_expiration
PUBSUB_CLIENTS: Dict[str, PubSub] = {}
async def cached_pubsub(redis: Redis, lock_key_channel: str) -> PubSub:
if lock_key_channel in PUBSUB_CLIENTS:
return PUBSUB_CLIENTS[lock_key_channel]
pubsub = redis.pubsub()
await pubsub.psubscribe(lock_key_channel)
PUBSUB_CLIENTS[lock_key_channel] = pubsub
return pubsub
class StateManagerRedis(StateManager): class StateManagerRedis(StateManager):
"""A state manager that stores states in redis.""" """A state manager that stores states in redis."""
@ -3392,7 +3413,7 @@ class StateManagerRedis(StateManager):
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
self.set_state( self.set_state(
token=_substate_key(client_token, substate), token=_substate_key(client_token, type(substate)),
state=substate, state=substate,
lock_id=lock_id, lock_id=lock_id,
) )
@ -3403,7 +3424,7 @@ class StateManagerRedis(StateManager):
pickle_state = state._serialize() pickle_state = state._serialize()
if pickle_state: if pickle_state:
await self.redis.set( await self.redis.set(
_substate_key(client_token, state), _substate_key(client_token, type(state)),
pickle_state, pickle_state,
ex=self.token_expiration, ex=self.token_expiration,
) )
@ -3485,8 +3506,8 @@ class StateManagerRedis(StateManager):
raise raise
if lock_key not in self._pubsub_locks: if lock_key not in self._pubsub_locks:
self._pubsub_locks[lock_key] = asyncio.Lock() self._pubsub_locks[lock_key] = asyncio.Lock()
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub: async with self._pubsub_locks[lock_key]:
await pubsub.psubscribe(lock_key_channel) pubsub = await cached_pubsub(self.redis, lock_key_channel)
while not state_is_locked: while not state_is_locked:
# wait for the lock to be released # wait for the lock to be released
while True: while True:

View File

@ -44,7 +44,7 @@ def BackgroundTask():
@rx.event(background=True) @rx.event(background=True)
async def fast_yielding(self): async def fast_yielding(self):
for _ in range(1000): for _ in range(10000):
yield State.increment() yield State.increment()
@rx.event @rx.event
@ -409,4 +409,4 @@ def test_fast_yielding(
assert background_task._poll_for(lambda: counter.text == "0", timeout=5) assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
fast_yielding_button.click() fast_yielding_button.click()
assert background_task._poll_for(lambda: counter.text == "1000", timeout=50) assert background_task._poll_for(lambda: counter.text == "10000", timeout=50)

View File

@ -1759,7 +1759,7 @@ def substate_token_redis(state_manager_redis, token):
Returns: Returns:
Token concatenated with the state_manager's state full_name. Token concatenated with the state_manager's state full_name.
""" """
return _substate_key(token, state_manager_redis.state) return _substate_key(token, type(state_manager_redis.state))
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1918,7 +1918,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# Get the state from the state manager directly and check that the value is updated # Get the state from the state manager directly and check that the value is updated
gotten_state = await mock_app.state_manager.get_state( gotten_state = await mock_app.state_manager.get_state(
_substate_key(grandchild_state.router.session.client_token, grandchild_state) _substate_key(
grandchild_state.router.session.client_token, type(grandchild_state)
)
) )
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# For in-process store, only one instance of the state exists # For in-process store, only one instance of the state exists