This commit is contained in:
abulvenz 2025-01-11 08:43:41 -08:00 committed by GitHub
commit acf0f06fa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 168 additions and 27 deletions

View File

@ -1108,7 +1108,9 @@ class App(MiddlewareMixin, LifespanMixin):
# When the state is modified reset dirty status and emit the delta to the frontend.
state._clean()
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
update=StateUpdate(
delta=delta, scopes=state.scopes_and_subscopes()
),
sid=state.router.session.session_id,
)
@ -1273,16 +1275,23 @@ async def process(
from reflex.utils import telemetry
try:
# Add request data to the state.
router_data = event.router_data
router_data.update(
{
constants.RouteVar.QUERY: format.format_query_params(event.router_data),
constants.RouteVar.CLIENT_TOKEN: event.token,
constants.RouteVar.SESSION_ID: sid,
constants.RouteVar.HEADERS: headers,
constants.RouteVar.CLIENT_IP: client_ip,
}
router_data = {}
if event.router_data:
# Add request data to the state.
router_data = event.router_data
router_data.update(
{
constants.RouteVar.QUERY: format.format_query_params(
event.router_data
),
constants.RouteVar.CLIENT_TOKEN: event.token,
constants.RouteVar.SESSION_ID: sid,
constants.RouteVar.HEADERS: headers,
constants.RouteVar.CLIENT_IP: client_ip,
}
)
print(
f"Processing event: {event.name} with payload: {event.payload} {event.substate_token}"
)
# Get the state for the session exclusively.
async with app.state_manager.modify_state(event.substate_token) as state:
@ -1319,7 +1328,7 @@ async def process(
else:
if app._process_background(state, event) is not None:
# `final=True` allows the frontend send more events immediately.
yield StateUpdate(final=True)
yield StateUpdate(final=True, scopes=state.scopes_and_subscopes())
return
# Process the event synchronously.
@ -1542,13 +1551,74 @@ class EventNamespace(AsyncNamespace):
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
async def emit_update(self, update: StateUpdate, sid: str) -> None:
async def emit_update(
self, update: StateUpdate, sid: str, room: str | None = None
) -> None:
"""Emit an update to the client.
Args:
update: The state update to send.
sid: The Socket.IO session id.
room: The room to send the update to.
"""
# TODO We don't know when to leave a room yet.
for receiver in update.scopes:
if self.sid_to_token[sid] != receiver:
room = receiver
if room not in self.rooms(sid):
print(f"Entering room `{room}`")
await self.enter_room(sid, room)
for room in self.rooms(sid):
if room not in update.scopes and room != sid:
print(f"Leaving room `{room}`")
await self.leave_room(sid, room)
delta_by_scope = {}
events_by_scope = {}
for event in update.events:
scope = self.token_to_sid.get(event.token, event.token)
events = events_by_scope.get(scope, [])
events.append(event)
events_by_scope[scope] = events
for state, delta in update.delta.items():
key = delta.get("_scope", sid)
d = delta_by_scope.get(key, {})
d.update({state: delta})
delta_by_scope[key] = d
for scope, deltas in delta_by_scope.items():
events = events_by_scope.get(scope, [])
print(f"Sending update to {scope} {events}")
single_update = StateUpdate(
delta=deltas, scopes=[scope], events=events, final=update.final
)
await asyncio.create_task(
self.emit(
str(constants.SocketEvent.EVENT), single_update.json(), to=scope
)
)
for key in events_by_scope:
if key not in delta_by_scope:
single_update = StateUpdate(
delta={},
scopes=[key],
events=events_by_scope.get(key, []),
final=update.final,
)
print(f"Sending event to {key}")
await asyncio.create_task(
self.emit(
str(constants.SocketEvent.EVENT), single_update.json(), to=key
)
)
update.scopes = []
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
@ -1590,6 +1660,8 @@ class EventNamespace(AsyncNamespace):
except (KeyError, IndexError):
client_ip = environ.get("REMOTE_ADDR", "0.0.0.0")
print(f"Received event {event.name} {event.token} from {client_ip}")
# Process the events.
async for update in process(self.app, event, sid, headers, client_ip):
# Emit the update from processing the event.

View File

@ -47,4 +47,4 @@ class HydrateMiddleware(Middleware):
state._clean()
# Return the state update.
return StateUpdate(delta=delta, events=[])
return StateUpdate(delta=delta, events=[], scopes=state.scopes_and_subscopes())

View File

@ -8,6 +8,8 @@ import copy
import dataclasses
import functools
import inspect
import os
import traceback
import json
import pickle
import sys
@ -135,6 +137,23 @@ if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
# Only warn about each state class size once.
_WARNED_ABOUT_STATE_SIZE: Set[str] = set()
def print_stack(depth: int = 3):
"""Print the current stacktrace to the console.
Args:
depth: Depth of the stack-trace to print
"""
stack = traceback.extract_stack()
stack.reverse()
print("stacktrace")
for idx in range(1, depth + 1):
stack_info = stack[idx]
print(
f" {stack_info.name} {os.path.basename(stack_info.filename)}:{stack_info.lineno}"
)
# Errors caught during pickling of state
HANDLED_PICKLE_ERRORS = (
pickle.PicklingError,
@ -481,11 +500,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
@classmethod
def __init_subclass__(cls, mixin: bool = False, **kwargs):
def __init_subclass__(
cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs
):
"""Do some magic for the subclass initialization.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises:
@ -499,6 +521,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if mixin:
return
# Set the scope of the state.
cls._scope = scope
# Handle locally-defined states for pickling.
if "<locals>" in cls.__qualname__:
cls._handle_local_def()
@ -795,7 +819,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
}
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
if cls._always_dirty_computed_vars:
if cls._always_dirty_computed_vars: # or cls._scope is not None:
# Tell parent classes that this substate has always dirty computed vars
state_name = cls.get_name()
parent_state = cls.get_parent_state()
@ -1532,9 +1556,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The requested state is missing, fetch from redis.
pass
parent_state = await state_manager.get_state(
token=_substate_key(
self.router.session.client_token, parent_state_name
),
token=_substate_key(self._get_token(), parent_state_name),
top_level=False,
get_substates=False,
parent_state=parent_state,
@ -1577,8 +1599,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
return await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
token=_substate_key(self._get_token(), state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
@ -1720,6 +1743,34 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
def scopes_and_subscopes(self) -> list[str]:
"""Recursively gathers all scopes of self and substates.
Returns:
A unique list of the scopes/token
"""
result = [self._get_token()]
for substate in self.substates.values():
subscopes = substate.scopes_and_subscopes()
for subscope in subscopes:
if subscope not in result:
result.append(subscope)
return result
def _get_token(self, other: type[BaseState] | None = None) -> str:
token = self.router.session.client_token
cls = other or self.__class__
if cls._scope is not None:
scope = None
if isinstance(cls._scope, str):
scope = f"static{cls._scope}"
else:
scope = f"shared{getattr(self, cls._scope._var_name)}"
token = scope
return token
def _as_state_update(
self,
handler: EventHandler,
@ -1741,7 +1792,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# get the delta from the root of the state tree
state = self._get_root_state()
token = self.router.session.client_token
token = self._get_token()
# Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token)
@ -1755,6 +1806,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
delta=delta,
events=fixed_events,
final=final if not handler.is_background else True,
scopes=state.scopes_and_subscopes(),
)
except Exception as ex:
state._clean()
@ -1970,6 +2022,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
.union(self._always_dirty_computed_vars)
)
if len(self.scopes_and_subscopes()) > 1 and "router" in delta_vars:
delta_vars.remove("router")
subdelta: Dict[str, Any] = {
prop: self.get_value(prop)
for prop in delta_vars
@ -1979,6 +2034,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
if self.__class__._scope is not None:
subdelta["_scope"] = self._get_token()
# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates.union(self._always_dirty_substates):
@ -2114,7 +2172,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
}
else:
computed_vars = {}
variables = {**base_vars, **computed_vars}
variables = {"_scope": self._get_token(), **base_vars, **computed_vars}
d = {
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
}
@ -2431,7 +2489,7 @@ class OnLoadInternalState(State):
return [
*fix_events(
load_events,
self.router.session.client_token,
self._get_token(),
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
@ -2833,6 +2891,8 @@ class StateUpdate:
# Whether this is the final state update for the event.
final: bool = True
scopes: list[str] = []
def json(self) -> str:
"""Convert the state update to a JSON string.
@ -2948,10 +3008,15 @@ class StateManagerMemory(StateManager):
Returns:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
# TODO This is a bit madness maybe.
token = self.states[token]._get_token()
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]
@override
@ -2975,8 +3040,6 @@ class StateManagerMemory(StateManager):
Yields:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = _split_substate_key(token)[0]
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:
@ -3394,6 +3457,10 @@ class StateManagerRedis(StateManager):
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
)
if parent_state is None:
parent_state = await self._get_parent_state(token)
if parent_state is not None:
token = f"{parent_state._get_token(state_cls)}_{state_path}"
# The deserialized or newly created (sub)state instance.
state = None
@ -3472,6 +3539,8 @@ class StateManagerRedis(StateManager):
)
client_token, substate_name = _split_substate_key(token)
client_token = state._get_token()
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
if state.parent_state is not None and state.get_full_name() != substate_name:
raise RuntimeError(