Merge 5ca563c117
into 1e7a37bcf9
This commit is contained in:
commit
acf0f06fa7
@ -1108,7 +1108,9 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
# When the state is modified reset dirty status and emit the delta to the frontend.
|
# When the state is modified reset dirty status and emit the delta to the frontend.
|
||||||
state._clean()
|
state._clean()
|
||||||
await self.event_namespace.emit_update(
|
await self.event_namespace.emit_update(
|
||||||
update=StateUpdate(delta=delta),
|
update=StateUpdate(
|
||||||
|
delta=delta, scopes=state.scopes_and_subscopes()
|
||||||
|
),
|
||||||
sid=state.router.session.session_id,
|
sid=state.router.session.session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1273,16 +1275,23 @@ async def process(
|
|||||||
from reflex.utils import telemetry
|
from reflex.utils import telemetry
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Add request data to the state.
|
router_data = {}
|
||||||
router_data = event.router_data
|
if event.router_data:
|
||||||
router_data.update(
|
# Add request data to the state.
|
||||||
{
|
router_data = event.router_data
|
||||||
constants.RouteVar.QUERY: format.format_query_params(event.router_data),
|
router_data.update(
|
||||||
constants.RouteVar.CLIENT_TOKEN: event.token,
|
{
|
||||||
constants.RouteVar.SESSION_ID: sid,
|
constants.RouteVar.QUERY: format.format_query_params(
|
||||||
constants.RouteVar.HEADERS: headers,
|
event.router_data
|
||||||
constants.RouteVar.CLIENT_IP: client_ip,
|
),
|
||||||
}
|
constants.RouteVar.CLIENT_TOKEN: event.token,
|
||||||
|
constants.RouteVar.SESSION_ID: sid,
|
||||||
|
constants.RouteVar.HEADERS: headers,
|
||||||
|
constants.RouteVar.CLIENT_IP: client_ip,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Processing event: {event.name} with payload: {event.payload} {event.substate_token}"
|
||||||
)
|
)
|
||||||
# Get the state for the session exclusively.
|
# Get the state for the session exclusively.
|
||||||
async with app.state_manager.modify_state(event.substate_token) as state:
|
async with app.state_manager.modify_state(event.substate_token) as state:
|
||||||
@ -1319,7 +1328,7 @@ async def process(
|
|||||||
else:
|
else:
|
||||||
if app._process_background(state, event) is not None:
|
if app._process_background(state, event) is not None:
|
||||||
# `final=True` allows the frontend send more events immediately.
|
# `final=True` allows the frontend send more events immediately.
|
||||||
yield StateUpdate(final=True)
|
yield StateUpdate(final=True, scopes=state.scopes_and_subscopes())
|
||||||
return
|
return
|
||||||
|
|
||||||
# Process the event synchronously.
|
# Process the event synchronously.
|
||||||
@ -1542,13 +1551,74 @@ class EventNamespace(AsyncNamespace):
|
|||||||
if disconnect_token:
|
if disconnect_token:
|
||||||
self.token_to_sid.pop(disconnect_token, None)
|
self.token_to_sid.pop(disconnect_token, None)
|
||||||
|
|
||||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
async def emit_update(
|
||||||
|
self, update: StateUpdate, sid: str, room: str | None = None
|
||||||
|
) -> None:
|
||||||
"""Emit an update to the client.
|
"""Emit an update to the client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
update: The state update to send.
|
update: The state update to send.
|
||||||
sid: The Socket.IO session id.
|
sid: The Socket.IO session id.
|
||||||
|
room: The room to send the update to.
|
||||||
"""
|
"""
|
||||||
|
# TODO We don't know when to leave a room yet.
|
||||||
|
for receiver in update.scopes:
|
||||||
|
if self.sid_to_token[sid] != receiver:
|
||||||
|
room = receiver
|
||||||
|
if room not in self.rooms(sid):
|
||||||
|
print(f"Entering room `{room}`")
|
||||||
|
await self.enter_room(sid, room)
|
||||||
|
|
||||||
|
for room in self.rooms(sid):
|
||||||
|
if room not in update.scopes and room != sid:
|
||||||
|
print(f"Leaving room `{room}`")
|
||||||
|
await self.leave_room(sid, room)
|
||||||
|
|
||||||
|
delta_by_scope = {}
|
||||||
|
events_by_scope = {}
|
||||||
|
|
||||||
|
for event in update.events:
|
||||||
|
scope = self.token_to_sid.get(event.token, event.token)
|
||||||
|
events = events_by_scope.get(scope, [])
|
||||||
|
events.append(event)
|
||||||
|
events_by_scope[scope] = events
|
||||||
|
|
||||||
|
for state, delta in update.delta.items():
|
||||||
|
key = delta.get("_scope", sid)
|
||||||
|
d = delta_by_scope.get(key, {})
|
||||||
|
d.update({state: delta})
|
||||||
|
delta_by_scope[key] = d
|
||||||
|
|
||||||
|
for scope, deltas in delta_by_scope.items():
|
||||||
|
events = events_by_scope.get(scope, [])
|
||||||
|
print(f"Sending update to {scope} {events}")
|
||||||
|
single_update = StateUpdate(
|
||||||
|
delta=deltas, scopes=[scope], events=events, final=update.final
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.create_task(
|
||||||
|
self.emit(
|
||||||
|
str(constants.SocketEvent.EVENT), single_update.json(), to=scope
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in events_by_scope:
|
||||||
|
if key not in delta_by_scope:
|
||||||
|
single_update = StateUpdate(
|
||||||
|
delta={},
|
||||||
|
scopes=[key],
|
||||||
|
events=events_by_scope.get(key, []),
|
||||||
|
final=update.final,
|
||||||
|
)
|
||||||
|
print(f"Sending event to {key}")
|
||||||
|
await asyncio.create_task(
|
||||||
|
self.emit(
|
||||||
|
str(constants.SocketEvent.EVENT), single_update.json(), to=key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
update.scopes = []
|
||||||
|
|
||||||
# Creating a task prevents the update from being blocked behind other coroutines.
|
# Creating a task prevents the update from being blocked behind other coroutines.
|
||||||
await asyncio.create_task(
|
await asyncio.create_task(
|
||||||
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
|
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
|
||||||
@ -1590,6 +1660,8 @@ class EventNamespace(AsyncNamespace):
|
|||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
client_ip = environ.get("REMOTE_ADDR", "0.0.0.0")
|
client_ip = environ.get("REMOTE_ADDR", "0.0.0.0")
|
||||||
|
|
||||||
|
print(f"Received event {event.name} {event.token} from {client_ip}")
|
||||||
|
|
||||||
# Process the events.
|
# Process the events.
|
||||||
async for update in process(self.app, event, sid, headers, client_ip):
|
async for update in process(self.app, event, sid, headers, client_ip):
|
||||||
# Emit the update from processing the event.
|
# Emit the update from processing the event.
|
||||||
|
@ -47,4 +47,4 @@ class HydrateMiddleware(Middleware):
|
|||||||
state._clean()
|
state._clean()
|
||||||
|
|
||||||
# Return the state update.
|
# Return the state update.
|
||||||
return StateUpdate(delta=delta, events=[])
|
return StateUpdate(delta=delta, events=[], scopes=state.scopes_and_subscopes())
|
||||||
|
@ -8,6 +8,8 @@ import copy
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
@ -135,6 +137,23 @@ if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
|
|||||||
# Only warn about each state class size once.
|
# Only warn about each state class size once.
|
||||||
_WARNED_ABOUT_STATE_SIZE: Set[str] = set()
|
_WARNED_ABOUT_STATE_SIZE: Set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def print_stack(depth: int = 3):
|
||||||
|
"""Print the current stacktrace to the console.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth: Depth of the stack-trace to print
|
||||||
|
"""
|
||||||
|
stack = traceback.extract_stack()
|
||||||
|
stack.reverse()
|
||||||
|
print("stacktrace")
|
||||||
|
for idx in range(1, depth + 1):
|
||||||
|
stack_info = stack[idx]
|
||||||
|
print(
|
||||||
|
f" {stack_info.name} {os.path.basename(stack_info.filename)}:{stack_info.lineno}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Errors caught during pickling of state
|
# Errors caught during pickling of state
|
||||||
HANDLED_PICKLE_ERRORS = (
|
HANDLED_PICKLE_ERRORS = (
|
||||||
pickle.PicklingError,
|
pickle.PicklingError,
|
||||||
@ -481,11 +500,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls, mixin: bool = False, **kwargs):
|
def __init_subclass__(
|
||||||
|
cls, mixin: bool = False, scope: Union[Var, str, None] = None, **kwargs
|
||||||
|
):
|
||||||
"""Do some magic for the subclass initialization.
|
"""Do some magic for the subclass initialization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mixin: Whether the subclass is a mixin and should not be initialized.
|
mixin: Whether the subclass is a mixin and should not be initialized.
|
||||||
|
scope: A var or string to set the scope of the state. The state will be shared across all states with the same scope value.
|
||||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -499,6 +521,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if mixin:
|
if mixin:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Set the scope of the state.
|
||||||
|
cls._scope = scope
|
||||||
# Handle locally-defined states for pickling.
|
# Handle locally-defined states for pickling.
|
||||||
if "<locals>" in cls.__qualname__:
|
if "<locals>" in cls.__qualname__:
|
||||||
cls._handle_local_def()
|
cls._handle_local_def()
|
||||||
@ -795,7 +819,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
|
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
|
||||||
if cls._always_dirty_computed_vars:
|
if cls._always_dirty_computed_vars: # or cls._scope is not None:
|
||||||
# Tell parent classes that this substate has always dirty computed vars
|
# Tell parent classes that this substate has always dirty computed vars
|
||||||
state_name = cls.get_name()
|
state_name = cls.get_name()
|
||||||
parent_state = cls.get_parent_state()
|
parent_state = cls.get_parent_state()
|
||||||
@ -1532,9 +1556,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# The requested state is missing, fetch from redis.
|
# The requested state is missing, fetch from redis.
|
||||||
pass
|
pass
|
||||||
parent_state = await state_manager.get_state(
|
parent_state = await state_manager.get_state(
|
||||||
token=_substate_key(
|
token=_substate_key(self._get_token(), parent_state_name),
|
||||||
self.router.session.client_token, parent_state_name
|
|
||||||
),
|
|
||||||
top_level=False,
|
top_level=False,
|
||||||
get_substates=False,
|
get_substates=False,
|
||||||
parent_state=parent_state,
|
parent_state=parent_state,
|
||||||
@ -1577,8 +1599,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||||
"(All states should already be available -- this is likely a bug).",
|
"(All states should already be available -- this is likely a bug).",
|
||||||
)
|
)
|
||||||
|
|
||||||
return await state_manager.get_state(
|
return await state_manager.get_state(
|
||||||
token=_substate_key(self.router.session.client_token, state_cls),
|
token=_substate_key(self._get_token(), state_cls),
|
||||||
top_level=False,
|
top_level=False,
|
||||||
get_substates=True,
|
get_substates=True,
|
||||||
parent_state=parent_state_of_state_cls,
|
parent_state=parent_state_of_state_cls,
|
||||||
@ -1720,6 +1743,34 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def scopes_and_subscopes(self) -> list[str]:
|
||||||
|
"""Recursively gathers all scopes of self and substates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A unique list of the scopes/token
|
||||||
|
"""
|
||||||
|
result = [self._get_token()]
|
||||||
|
for substate in self.substates.values():
|
||||||
|
subscopes = substate.scopes_and_subscopes()
|
||||||
|
for subscope in subscopes:
|
||||||
|
if subscope not in result:
|
||||||
|
result.append(subscope)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_token(self, other: type[BaseState] | None = None) -> str:
|
||||||
|
token = self.router.session.client_token
|
||||||
|
cls = other or self.__class__
|
||||||
|
if cls._scope is not None:
|
||||||
|
scope = None
|
||||||
|
if isinstance(cls._scope, str):
|
||||||
|
scope = f"static{cls._scope}"
|
||||||
|
else:
|
||||||
|
scope = f"shared{getattr(self, cls._scope._var_name)}"
|
||||||
|
|
||||||
|
token = scope
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
def _as_state_update(
|
def _as_state_update(
|
||||||
self,
|
self,
|
||||||
handler: EventHandler,
|
handler: EventHandler,
|
||||||
@ -1741,7 +1792,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# get the delta from the root of the state tree
|
# get the delta from the root of the state tree
|
||||||
state = self._get_root_state()
|
state = self._get_root_state()
|
||||||
|
|
||||||
token = self.router.session.client_token
|
token = self._get_token()
|
||||||
|
|
||||||
# Convert valid EventHandler and EventSpec into Event
|
# Convert valid EventHandler and EventSpec into Event
|
||||||
fixed_events = fix_events(self._check_valid(handler, events), token)
|
fixed_events = fix_events(self._check_valid(handler, events), token)
|
||||||
@ -1755,6 +1806,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
delta=delta,
|
delta=delta,
|
||||||
events=fixed_events,
|
events=fixed_events,
|
||||||
final=final if not handler.is_background else True,
|
final=final if not handler.is_background else True,
|
||||||
|
scopes=state.scopes_and_subscopes(),
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
state._clean()
|
state._clean()
|
||||||
@ -1970,6 +2022,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
.union(self._always_dirty_computed_vars)
|
.union(self._always_dirty_computed_vars)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(self.scopes_and_subscopes()) > 1 and "router" in delta_vars:
|
||||||
|
delta_vars.remove("router")
|
||||||
|
|
||||||
subdelta: Dict[str, Any] = {
|
subdelta: Dict[str, Any] = {
|
||||||
prop: self.get_value(prop)
|
prop: self.get_value(prop)
|
||||||
for prop in delta_vars
|
for prop in delta_vars
|
||||||
@ -1979,6 +2034,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if len(subdelta) > 0:
|
if len(subdelta) > 0:
|
||||||
delta[self.get_full_name()] = subdelta
|
delta[self.get_full_name()] = subdelta
|
||||||
|
|
||||||
|
if self.__class__._scope is not None:
|
||||||
|
subdelta["_scope"] = self._get_token()
|
||||||
|
|
||||||
# Recursively find the substate deltas.
|
# Recursively find the substate deltas.
|
||||||
substates = self.substates
|
substates = self.substates
|
||||||
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
||||||
@ -2114,7 +2172,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
computed_vars = {}
|
computed_vars = {}
|
||||||
variables = {**base_vars, **computed_vars}
|
variables = {"_scope": self._get_token(), **base_vars, **computed_vars}
|
||||||
d = {
|
d = {
|
||||||
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
||||||
}
|
}
|
||||||
@ -2431,7 +2489,7 @@ class OnLoadInternalState(State):
|
|||||||
return [
|
return [
|
||||||
*fix_events(
|
*fix_events(
|
||||||
load_events,
|
load_events,
|
||||||
self.router.session.client_token,
|
self._get_token(),
|
||||||
router_data=self.router_data,
|
router_data=self.router_data,
|
||||||
),
|
),
|
||||||
State.set_is_hydrated(True), # type: ignore
|
State.set_is_hydrated(True), # type: ignore
|
||||||
@ -2833,6 +2891,8 @@ class StateUpdate:
|
|||||||
# Whether this is the final state update for the event.
|
# Whether this is the final state update for the event.
|
||||||
final: bool = True
|
final: bool = True
|
||||||
|
|
||||||
|
scopes: list[str] = []
|
||||||
|
|
||||||
def json(self) -> str:
|
def json(self) -> str:
|
||||||
"""Convert the state update to a JSON string.
|
"""Convert the state update to a JSON string.
|
||||||
|
|
||||||
@ -2948,10 +3008,15 @@ class StateManagerMemory(StateManager):
|
|||||||
Returns:
|
Returns:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
"""
|
"""
|
||||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
||||||
token = _split_substate_key(token)[0]
|
|
||||||
if token not in self.states:
|
if token not in self.states:
|
||||||
self.states[token] = self.state(_reflex_internal_init=True)
|
self.states[token] = self.state(_reflex_internal_init=True)
|
||||||
|
|
||||||
|
# TODO This is a bit madness maybe.
|
||||||
|
token = self.states[token]._get_token()
|
||||||
|
|
||||||
|
if token not in self.states:
|
||||||
|
self.states[token] = self.state(_reflex_internal_init=True)
|
||||||
|
|
||||||
return self.states[token]
|
return self.states[token]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -2975,8 +3040,6 @@ class StateManagerMemory(StateManager):
|
|||||||
Yields:
|
Yields:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
"""
|
"""
|
||||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
|
||||||
token = _split_substate_key(token)[0]
|
|
||||||
if token not in self._states_locks:
|
if token not in self._states_locks:
|
||||||
async with self._state_manager_lock:
|
async with self._state_manager_lock:
|
||||||
if token not in self._states_locks:
|
if token not in self._states_locks:
|
||||||
@ -3394,6 +3457,10 @@ class StateManagerRedis(StateManager):
|
|||||||
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if parent_state is None:
|
||||||
|
parent_state = await self._get_parent_state(token)
|
||||||
|
if parent_state is not None:
|
||||||
|
token = f"{parent_state._get_token(state_cls)}_{state_path}"
|
||||||
# The deserialized or newly created (sub)state instance.
|
# The deserialized or newly created (sub)state instance.
|
||||||
state = None
|
state = None
|
||||||
|
|
||||||
@ -3472,6 +3539,8 @@ class StateManagerRedis(StateManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
client_token, substate_name = _split_substate_key(token)
|
client_token, substate_name = _split_substate_key(token)
|
||||||
|
client_token = state._get_token()
|
||||||
|
|
||||||
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
||||||
if state.parent_state is not None and state.get_full_name() != substate_name:
|
if state.parent_state is not None and state.get_full_name() != substate_name:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
Loading…
Reference in New Issue
Block a user