feat: PoC for shared states.

This commit is contained in:
Andreas Eismann 2024-07-16 22:51:11 +02:00
parent 93231f8168
commit 3bcdc12c26
No known key found for this signature in database
3 changed files with 43 additions and 13 deletions

View File

@ -1460,16 +1460,19 @@ class EventNamespace(AsyncNamespace):
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
async def emit_update(self, update: StateUpdate, sid: str) -> None:
async def emit_update(
self, update: StateUpdate, sid: str, room: str | None = None
) -> None:
"""Emit an update to the client.
Args:
update: The state update to send.
sid: The Socket.IO session id.
room: The room to send the update to.
"""
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid)
)
async def on_event(self, sid, data):

View File

@ -442,11 +442,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
@classmethod
def __init_subclass__(cls, mixin: bool = False, **kwargs):
def __init_subclass__(
cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs
):
"""Do some magic for the subclass initialization.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises:
@ -460,6 +463,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if mixin:
return
# Set the scope of the state.
cls._scope = scope
# Validate the module name.
cls._validate_module_name()
@ -1270,9 +1276,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The requested state is missing, fetch from redis.
pass
parent_state = await state_manager.get_state(
token=_substate_key(
self.router.session.client_token, parent_state_name
),
token=_substate_key(self._get_token(), parent_state_name),
top_level=False,
get_substates=False,
parent_state=parent_state,
@ -1318,8 +1322,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
return await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
token=_substate_key(self._get_token(), state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
@ -1425,6 +1430,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
def _get_token(self) -> str:
token = self.router.session.client_token
if self.__class__._scope is not None:
scope = None
if isinstance(self.__class__._scope, str):
scope = self.__class__._scope
else:
scope = getattr(self, self.__class__._scope._var_name)
token = scope
return token
def _as_state_update(
self,
handler: EventHandler,
@ -1448,7 +1466,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
while state.parent_state is not None:
state = state.parent_state
token = self.router.session.client_token
token = self._get_token()
# Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token)
@ -1909,7 +1927,7 @@ class OnLoadInternalState(State):
return [
*fix_events(
load_events,
self.router.session.client_token,
self._get_token(),
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
@ -2328,10 +2346,15 @@ class StateManagerMemory(StateManager):
Returns:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
# TODO This is a bit madness maybe.
token = self.states[token]._get_token()
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]
async def set_state(self, token: str, state: BaseState):
@ -2353,8 +2376,6 @@ class StateManagerMemory(StateManager):
Yields:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:

View File

@ -383,6 +383,8 @@ class Var:
# Extra metadata associated with the Var
_var_data: Optional[VarData]
_var_is_scope: bool = False
@classmethod
def create(
cls,
@ -390,6 +392,7 @@ class Var:
_var_is_local: bool = True,
_var_is_string: bool | None = None,
_var_data: Optional[VarData] = None,
_var_is_scope: bool = False,
) -> Var | None:
"""Create a var from a value.
@ -455,6 +458,7 @@ class Var:
_var_is_local=_var_is_local,
_var_is_string=_var_is_string if _var_is_string is not None else False,
_var_data=_var_data,
_var_is_scope=_var_is_scope,
)
@classmethod
@ -1866,6 +1870,8 @@ class BaseVar(Var):
# Extra metadata associated with the Var
_var_data: Optional[VarData] = dataclasses.field(default=None)
_var_is_scope: bool = dataclasses.field(default=False)
def __hash__(self) -> int:
"""Define a hash function for a var.