feat: PoC for shared states.
This commit is contained in:
parent
93231f8168
commit
3bcdc12c26
@ -1460,16 +1460,19 @@ class EventNamespace(AsyncNamespace):
|
||||
if disconnect_token:
|
||||
self.token_to_sid.pop(disconnect_token, None)
|
||||
|
||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
||||
async def emit_update(
|
||||
self, update: StateUpdate, sid: str, room: str | None = None
|
||||
) -> None:
|
||||
"""Emit an update to the client.
|
||||
|
||||
Args:
|
||||
update: The state update to send.
|
||||
sid: The Socket.IO session id.
|
||||
room: The room to send the update to.
|
||||
"""
|
||||
# Creating a task prevents the update from being blocked behind other coroutines.
|
||||
await asyncio.create_task(
|
||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid)
|
||||
)
|
||||
|
||||
async def on_event(self, sid, data):
|
||||
|
@ -442,11 +442,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, mixin: bool = False, **kwargs):
|
||||
def __init_subclass__(
|
||||
cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs
|
||||
):
|
||||
"""Do some magic for the subclass initialization.
|
||||
|
||||
Args:
|
||||
mixin: Whether the subclass is a mixin and should not be initialized.
|
||||
scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value.
|
||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||
|
||||
Raises:
|
||||
@ -460,6 +463,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if mixin:
|
||||
return
|
||||
|
||||
# Set the scope of the state.
|
||||
cls._scope = scope
|
||||
|
||||
# Validate the module name.
|
||||
cls._validate_module_name()
|
||||
|
||||
@ -1270,9 +1276,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The requested state is missing, fetch from redis.
|
||||
pass
|
||||
parent_state = await state_manager.get_state(
|
||||
token=_substate_key(
|
||||
self.router.session.client_token, parent_state_name
|
||||
),
|
||||
token=_substate_key(self._get_token(), parent_state_name),
|
||||
top_level=False,
|
||||
get_substates=False,
|
||||
parent_state=parent_state,
|
||||
@ -1318,8 +1322,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||
"(All states should already be available -- this is likely a bug).",
|
||||
)
|
||||
|
||||
return await state_manager.get_state(
|
||||
token=_substate_key(self.router.session.client_token, state_cls),
|
||||
token=_substate_key(self._get_token(), state_cls),
|
||||
top_level=False,
|
||||
get_substates=True,
|
||||
parent_state=parent_state_of_state_cls,
|
||||
@ -1425,6 +1430,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||
)
|
||||
|
||||
def _get_token(self) -> str:
|
||||
token = self.router.session.client_token
|
||||
if self.__class__._scope is not None:
|
||||
scope = None
|
||||
if isinstance(self.__class__._scope, str):
|
||||
scope = self.__class__._scope
|
||||
else:
|
||||
scope = getattr(self, self.__class__._scope._var_name)
|
||||
|
||||
token = scope
|
||||
|
||||
return token
|
||||
|
||||
def _as_state_update(
|
||||
self,
|
||||
handler: EventHandler,
|
||||
@ -1448,7 +1466,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
while state.parent_state is not None:
|
||||
state = state.parent_state
|
||||
|
||||
token = self.router.session.client_token
|
||||
token = self._get_token()
|
||||
|
||||
# Convert valid EventHandler and EventSpec into Event
|
||||
fixed_events = fix_events(self._check_valid(handler, events), token)
|
||||
@ -1909,7 +1927,7 @@ class OnLoadInternalState(State):
|
||||
return [
|
||||
*fix_events(
|
||||
load_events,
|
||||
self.router.session.client_token,
|
||||
self._get_token(),
|
||||
router_data=self.router_data,
|
||||
),
|
||||
State.set_is_hydrated(True), # type: ignore
|
||||
@ -2328,10 +2346,15 @@ class StateManagerMemory(StateManager):
|
||||
Returns:
|
||||
The state for the token.
|
||||
"""
|
||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
||||
token = _split_substate_key(token)[0]
|
||||
if token not in self.states:
|
||||
self.states[token] = self.state(_reflex_internal_init=True)
|
||||
|
||||
# TODO This is a bit madness maybe.
|
||||
token = self.states[token]._get_token()
|
||||
|
||||
if token not in self.states:
|
||||
self.states[token] = self.state(_reflex_internal_init=True)
|
||||
|
||||
return self.states[token]
|
||||
|
||||
async def set_state(self, token: str, state: BaseState):
|
||||
@ -2353,8 +2376,6 @@ class StateManagerMemory(StateManager):
|
||||
Yields:
|
||||
The state for the token.
|
||||
"""
|
||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
||||
token = _split_substate_key(token)[0]
|
||||
if token not in self._states_locks:
|
||||
async with self._state_manager_lock:
|
||||
if token not in self._states_locks:
|
||||
|
@ -383,6 +383,8 @@ class Var:
|
||||
# Extra metadata associated with the Var
|
||||
_var_data: Optional[VarData]
|
||||
|
||||
_var_is_scope: bool = False
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@ -390,6 +392,7 @@ class Var:
|
||||
_var_is_local: bool = True,
|
||||
_var_is_string: bool | None = None,
|
||||
_var_data: Optional[VarData] = None,
|
||||
_var_is_scope: bool = False,
|
||||
) -> Var | None:
|
||||
"""Create a var from a value.
|
||||
|
||||
@ -455,6 +458,7 @@ class Var:
|
||||
_var_is_local=_var_is_local,
|
||||
_var_is_string=_var_is_string if _var_is_string is not None else False,
|
||||
_var_data=_var_data,
|
||||
_var_is_scope=_var_is_scope,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1866,6 +1870,8 @@ class BaseVar(Var):
|
||||
# Extra metadata associated with the Var
|
||||
_var_data: Optional[VarData] = dataclasses.field(default=None)
|
||||
|
||||
_var_is_scope: bool = dataclasses.field(default=False)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Define a hash function for a var.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user