flatten StateManagerRedis.get_state algorithm
simplify fetching of states and avoid repeatedly fetching the same state
This commit is contained in:
parent
14c8aa45a6
commit
ca3c0fd723
392
reflex/state.py
392
reflex/state.py
@ -1465,65 +1465,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
||||
"""Find the name of the nearest common ancestor shared by this and the other state.
|
||||
|
||||
Args:
|
||||
other: The other state.
|
||||
|
||||
Returns:
|
||||
Full name of the nearest common ancestor.
|
||||
"""
|
||||
common_ancestor_parts = []
|
||||
for part1, part2 in zip(
|
||||
cls.get_full_name().split("."),
|
||||
other.get_full_name().split("."),
|
||||
):
|
||||
if part1 != part2:
|
||||
break
|
||||
common_ancestor_parts.append(part1)
|
||||
return ".".join(common_ancestor_parts)
|
||||
|
||||
@classmethod
|
||||
def _determine_missing_parent_states(
|
||||
cls, target_state_cls: Type[BaseState]
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Determine the missing parent states between the target_state_cls and common ancestor of this state.
|
||||
|
||||
Args:
|
||||
target_state_cls: The class of the state to find missing parent states for.
|
||||
|
||||
Returns:
|
||||
The name of the common ancestor and the list of missing parent states.
|
||||
"""
|
||||
common_ancestor_name = cls._get_common_ancestor(target_state_cls)
|
||||
common_ancestor_parts = common_ancestor_name.split(".")
|
||||
target_state_parts = tuple(target_state_cls.get_full_name().split("."))
|
||||
relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
|
||||
|
||||
# Determine which parent states to fetch from the common ancestor down to the target_state_cls.
|
||||
fetch_parent_states = [common_ancestor_name]
|
||||
for relative_parent_state_name in relative_target_state_parts:
|
||||
fetch_parent_states.append(
|
||||
".".join((fetch_parent_states[-1], relative_parent_state_name))
|
||||
)
|
||||
|
||||
return common_ancestor_name, fetch_parent_states[1:-1]
|
||||
|
||||
def _get_parent_states(self) -> list[tuple[str, BaseState]]:
|
||||
"""Get all parent state instances up to the root of the state tree.
|
||||
|
||||
Returns:
|
||||
A list of tuples containing the name and the instance of each parent state.
|
||||
"""
|
||||
parent_states_with_name = []
|
||||
parent_state = self
|
||||
while parent_state.parent_state is not None:
|
||||
parent_state = parent_state.parent_state
|
||||
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
||||
return parent_states_with_name
|
||||
|
||||
def _get_root_state(self) -> BaseState:
|
||||
"""Get the root state of the state tree.
|
||||
|
||||
@ -1555,9 +1496,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||
"(All states should already be available -- this is likely a bug).",
|
||||
)
|
||||
state_in_redis = await state_manager._link_arbitrary_state(
|
||||
self,
|
||||
state_cls,
|
||||
state_in_redis = await state_manager.get_state(
|
||||
token=_substate_key(self.router.session.client_token, state_cls),
|
||||
top_level=False,
|
||||
for_state_instance=self,
|
||||
)
|
||||
|
||||
if not isinstance(state_in_redis, state_cls):
|
||||
@ -1944,54 +1886,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if include_backend or not self.computed_vars[cvar]._backend
|
||||
}
|
||||
|
||||
async def _recursively_populate_dependent_substates(
|
||||
self,
|
||||
seen_classes: set[type[BaseState]] | None = None,
|
||||
) -> set[type[BaseState]]:
|
||||
"""Fetch all substates that have computed var dependencies on this state.
|
||||
|
||||
Args:
|
||||
seen_classes: set of classes that have already been seen to prevent infinite recursion.
|
||||
|
||||
Returns:
|
||||
The set of classes that were processed (mostly for testability).
|
||||
"""
|
||||
if seen_classes is None:
|
||||
print(
|
||||
f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:"
|
||||
)
|
||||
seen_classes = set()
|
||||
if type(self) in seen_classes:
|
||||
return seen_classes
|
||||
seen_classes.add(type(self))
|
||||
populated_substate_instances = {}
|
||||
for substate_cls in {
|
||||
self.get_class_substate((self.get_name(), *substate_name.split(".")))
|
||||
for substate_name in self._always_dirty_substates
|
||||
}:
|
||||
# _always_dirty_substates need to be fetched to recalc computed vars.
|
||||
if substate_cls not in populated_substate_instances:
|
||||
print(f"fetching always dirty {substate_cls}")
|
||||
populated_substate_instances[substate_cls] = await self.get_state(
|
||||
substate_cls
|
||||
)
|
||||
for dep_set in self._var_dependencies.values():
|
||||
for substate_name, _ in dep_set:
|
||||
if substate_name == self.get_full_name():
|
||||
# Do NOT fetch our own state instance.
|
||||
continue
|
||||
substate_cls = self.get_root_state().get_class_substate(substate_name)
|
||||
if substate_cls not in populated_substate_instances:
|
||||
print(f"fetching dependent {substate_cls}")
|
||||
populated_substate_instances[substate_cls] = await self.get_state(
|
||||
substate_cls
|
||||
)
|
||||
for substate in populated_substate_instances.values():
|
||||
await substate._recursively_populate_dependent_substates(
|
||||
seen_classes=seen_classes,
|
||||
)
|
||||
return seen_classes
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
|
||||
@ -3316,179 +3210,74 @@ class StateManagerRedis(StateManager):
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
async def _get_parent_state(
|
||||
self, token: str, state: BaseState | None = None
|
||||
) -> BaseState | None:
|
||||
"""Get the parent state for the state requested in the token.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for (_substate_key).
|
||||
state: The state instance to get parent state for.
|
||||
|
||||
Returns:
|
||||
The parent state for the state requested by the token or None if there is no such parent.
|
||||
"""
|
||||
parent_state = None
|
||||
client_token, state_path = _split_substate_key(token)
|
||||
parent_state_name = state_path.rpartition(".")[0]
|
||||
if parent_state_name:
|
||||
cached_substates = None
|
||||
if state is not None:
|
||||
cached_substates = [state]
|
||||
# Retrieve the parent state to populate event handlers onto this substate.
|
||||
parent_state = await self.get_state(
|
||||
token=_substate_key(client_token, parent_state_name),
|
||||
top_level=False,
|
||||
get_substates=False,
|
||||
cached_substates=cached_substates,
|
||||
)
|
||||
return parent_state
|
||||
|
||||
async def _populate_parent_states(
|
||||
self, calling_state: BaseState, target_state_cls: Type[BaseState]
|
||||
):
|
||||
"""Populate substates in the tree between the target_state_cls and common ancestor of calling_state.
|
||||
|
||||
Args:
|
||||
calling_state: The substate instance requesting subtree population.
|
||||
target_state_cls: The class of the state to populate parent states for.
|
||||
|
||||
Returns:
|
||||
The parent state instance of target_state_cls.
|
||||
"""
|
||||
# Find the missing parent states up to the common ancestor.
|
||||
(
|
||||
common_ancestor_name,
|
||||
missing_parent_states,
|
||||
) = calling_state._determine_missing_parent_states(target_state_cls)
|
||||
|
||||
# Fetch all missing parent states and link them up to the common ancestor.
|
||||
parent_states_tuple = calling_state._get_parent_states()
|
||||
root_state = parent_states_tuple[-1][1]
|
||||
parent_states_by_name = dict(parent_states_tuple)
|
||||
parent_state = parent_states_by_name[common_ancestor_name]
|
||||
for parent_state_name in missing_parent_states:
|
||||
try:
|
||||
parent_state = root_state.get_substate(parent_state_name.split("."))
|
||||
# The requested state is already cached, do NOT fetch it again.
|
||||
continue
|
||||
except ValueError:
|
||||
# The requested state is missing, fetch from redis.
|
||||
pass
|
||||
parent_state = await self.get_state(
|
||||
token=_substate_key(
|
||||
calling_state.router.session.client_token, parent_state_name
|
||||
),
|
||||
top_level=False,
|
||||
get_substates=False,
|
||||
parent_state=parent_state,
|
||||
)
|
||||
|
||||
# Return the direct parent of target_state_cls for subsequent linking.
|
||||
return parent_state
|
||||
|
||||
async def _link_arbitrary_state(
|
||||
self, calling_state: BaseState, state_cls: Type[T_STATE]
|
||||
) -> T_STATE:
|
||||
"""Get a state instance from redis.
|
||||
|
||||
Args:
|
||||
calling_state: The state instance requesting the newly linked instance of state_cls.
|
||||
state_cls: The class of the state to link into the tree.
|
||||
|
||||
Returns:
|
||||
The instance of state_cls associated with calling_state's client_token.
|
||||
|
||||
Raises:
|
||||
StateMismatchError: If the state instance is not of the expected type.
|
||||
"""
|
||||
# Fetch all missing parent states from redis.
|
||||
parent_state_of_state_cls = await self._populate_parent_states(
|
||||
calling_state, state_cls
|
||||
)
|
||||
|
||||
# Then get the target state and all its substates.
|
||||
state_in_redis = await self.get_state(
|
||||
token=_substate_key(calling_state.router.session.client_token, state_cls),
|
||||
top_level=False,
|
||||
get_substates=True,
|
||||
parent_state=parent_state_of_state_cls,
|
||||
)
|
||||
|
||||
return state_in_redis
|
||||
|
||||
async def _populate_substates(
|
||||
def _get_required_state_classes(
|
||||
self,
|
||||
token: str,
|
||||
state: BaseState,
|
||||
all_substates: bool = False,
|
||||
):
|
||||
"""Fetch and link substates for the given state instance.
|
||||
|
||||
There is no return value; the side-effect is that `state` will have `substates` populated,
|
||||
and each substate will have its `parent_state` set to `state`.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for.
|
||||
state: The state instance to populate substates for.
|
||||
all_substates: Whether to fetch all substates or just required substates.
|
||||
"""
|
||||
client_token, _ = _split_substate_key(token)
|
||||
|
||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
||||
fetch_substates = state._get_potentially_dirty_states()
|
||||
if all_substates:
|
||||
# All substates are requested.
|
||||
fetch_substates.update(state.get_substates())
|
||||
|
||||
tasks = {}
|
||||
link_tasks = set()
|
||||
# Retrieve the necessary substates from redis.
|
||||
for substate_cls in fetch_substates:
|
||||
if substate_cls.get_name() in state.substates:
|
||||
continue
|
||||
substate_name = substate_cls.get_name()
|
||||
if substate_cls in state.get_substates():
|
||||
tasks[substate_name] = asyncio.create_task(
|
||||
self.get_state(
|
||||
token=_substate_key(client_token, substate_cls),
|
||||
top_level=False,
|
||||
get_substates=all_substates,
|
||||
parent_state=state,
|
||||
)
|
||||
target_state_cls: Type[BaseState],
|
||||
subclasses: bool = False,
|
||||
required_state_classes: set[Type[BaseState]] | None = None,
|
||||
) -> set[Type[BaseState]]:
|
||||
if required_state_classes is None:
|
||||
required_state_classes = set()
|
||||
# Get the substates if requested.
|
||||
if subclasses:
|
||||
for substate in target_state_cls.get_substates():
|
||||
self._get_required_state_classes(
|
||||
substate,
|
||||
subclasses=True,
|
||||
required_state_classes=required_state_classes,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
state._get_root_state().get_substate(substate_name.split("."))
|
||||
except ValueError:
|
||||
# The requested state is missing, so fetch and link it (and its parents).
|
||||
link_tasks.add(
|
||||
asyncio.create_task(
|
||||
self._link_arbitrary_state(state, substate_cls)
|
||||
)
|
||||
)
|
||||
if target_state_cls in required_state_classes:
|
||||
return required_state_classes
|
||||
required_state_classes.add(target_state_cls)
|
||||
|
||||
for substate_name, substate_task in tasks.items():
|
||||
state.substates[substate_name] = await substate_task
|
||||
await asyncio.gather(*link_tasks)
|
||||
# Get dependent substates.
|
||||
for pd_substates in target_state_cls._get_potentially_dirty_states():
|
||||
self._get_required_state_classes(
|
||||
pd_substates,
|
||||
subclasses=False,
|
||||
required_state_classes=required_state_classes,
|
||||
)
|
||||
|
||||
# Get the parent state if it exists.
|
||||
if parent_state := target_state_cls.get_parent_state():
|
||||
self._get_required_state_classes(
|
||||
parent_state,
|
||||
subclasses=False,
|
||||
required_state_classes=required_state_classes,
|
||||
)
|
||||
return required_state_classes
|
||||
|
||||
def _get_populated_states(
|
||||
self,
|
||||
target_state: BaseState,
|
||||
populated_states: dict[str, BaseState] | None = None,
|
||||
) -> dict[str, BaseState]:
|
||||
if populated_states is None:
|
||||
populated_states = {}
|
||||
if target_state.get_full_name() in populated_states:
|
||||
return populated_states
|
||||
populated_states[target_state.get_full_name()] = target_state
|
||||
for substate in target_state.substates.values():
|
||||
self._get_populated_states(substate, populated_states=populated_states)
|
||||
if target_state.parent_state is not None:
|
||||
self._get_populated_states(
|
||||
target_state.parent_state, populated_states=populated_states
|
||||
)
|
||||
return populated_states
|
||||
|
||||
@override
|
||||
async def get_state(
|
||||
self,
|
||||
token: str,
|
||||
top_level: bool = True,
|
||||
get_substates: bool = True,
|
||||
parent_state: BaseState | None = None,
|
||||
cached_substates: list[BaseState] | None = None,
|
||||
for_state_instance: BaseState | None = None,
|
||||
) -> BaseState:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for.
|
||||
top_level: If true, return an instance of the top-level state (self.state).
|
||||
get_substates: If true, also retrieve substates.
|
||||
parent_state: If provided, use this parent_state instead of getting it from redis.
|
||||
cached_substates: If provided, attach these substates to the state.
|
||||
for_state_instance: If provided, attach the requested states to this existing state tree.
|
||||
|
||||
Returns:
|
||||
The state for the token.
|
||||
@ -3497,7 +3286,7 @@ class StateManagerRedis(StateManager):
|
||||
RuntimeError: when the state_cls is not specified in the token
|
||||
"""
|
||||
# Split the actual token from the fully qualified substate name.
|
||||
_, state_path = _split_substate_key(token)
|
||||
token, state_path = _split_substate_key(token)
|
||||
if state_path:
|
||||
# Get the State class associated with the given path.
|
||||
state_cls = self.state.get_class_substate(state_path)
|
||||
@ -3506,37 +3295,44 @@ class StateManagerRedis(StateManager):
|
||||
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
||||
)
|
||||
|
||||
# The deserialized or newly created (sub)state instance.
|
||||
state = None
|
||||
# Determine which states we already have.
|
||||
flat_state_tree: dict[str, BaseState] = (
|
||||
self._get_populated_states(for_state_instance) if for_state_instance else {}
|
||||
)
|
||||
|
||||
# Fetch the serialized substate from redis.
|
||||
redis_state = await self.redis.get(token)
|
||||
# Determine which states from the tree need to be fetched.
|
||||
required_state_classes = self._get_required_state_classes(
|
||||
state_cls, subclasses=True
|
||||
) - {type(s) for s in flat_state_tree.values()}
|
||||
|
||||
if redis_state is not None:
|
||||
# Deserialize the substate.
|
||||
with contextlib.suppress(StateSchemaMismatchError):
|
||||
state = BaseState._deserialize(data=redis_state)
|
||||
if state is None:
|
||||
# Key didn't exist or schema mismatch so create a new instance for this token.
|
||||
state = state_cls(
|
||||
init_substates=False,
|
||||
_reflex_internal_init=True,
|
||||
)
|
||||
# Populate parent state if missing and requested.
|
||||
if parent_state is None:
|
||||
parent_state = await self._get_parent_state(token, state)
|
||||
# Set up Bidirectional linkage between this state and its parent.
|
||||
if parent_state is not None:
|
||||
parent_state.substates[state.get_name()] = state
|
||||
state.parent_state = parent_state
|
||||
# Avoid fetching substates multiple times.
|
||||
if cached_substates:
|
||||
for substate in cached_substates:
|
||||
state.substates[substate.get_name()] = substate
|
||||
if substate.parent_state is None:
|
||||
substate.parent_state = state
|
||||
# Populate substates if requested.
|
||||
await self._populate_substates(token, state, all_substates=get_substates)
|
||||
for state_cls in sorted(
|
||||
required_state_classes, key=lambda x: x.get_full_name()
|
||||
):
|
||||
state = None
|
||||
redis_state = await self.redis.get(_substate_key(token, state_cls))
|
||||
|
||||
if redis_state is not None:
|
||||
# Deserialize the substate.
|
||||
with contextlib.suppress(StateSchemaMismatchError):
|
||||
state = BaseState._deserialize(data=redis_state)
|
||||
if state is None:
|
||||
# Key didn't exist or schema mismatch so create a new instance for this token.
|
||||
state = state_cls(
|
||||
init_substates=False,
|
||||
_reflex_internal_init=True,
|
||||
)
|
||||
flat_state_tree[state.get_full_name()] = state
|
||||
if state.get_parent_state() is not None:
|
||||
parent_state_name, _dot, state_name = state.get_full_name().rpartition(
|
||||
"."
|
||||
)
|
||||
parent_state = flat_state_tree.get(parent_state_name)
|
||||
if parent_state is None:
|
||||
raise Exception(
|
||||
f"Parent state should get fetched first... got {state.get_full_name()} instead"
|
||||
)
|
||||
parent_state.substates[state_name] = state
|
||||
state.parent_state = parent_state
|
||||
|
||||
# To retain compatibility with previous implementation, by default, we return
|
||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
||||
|
@ -3212,8 +3212,13 @@ def test_potentially_dirty_substates():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_var_dep() -> None:
|
||||
"""Test that router var dependencies are correctly tracked."""
|
||||
async def test_router_var_dep(state_manager: StateManager, token: str) -> None:
|
||||
"""Test that router var dependencies are correctly tracked.
|
||||
|
||||
Args:
|
||||
state_manager: A state manager.
|
||||
token: A token.
|
||||
"""
|
||||
|
||||
class RouterVarParentState(State):
|
||||
"""A parent state for testing router var dependency."""
|
||||
@ -3233,24 +3238,17 @@ async def test_router_var_dep() -> None:
|
||||
assert foo._deps(objclass=RouterVarDepState) == {
|
||||
RouterVarDepState.get_full_name(): {"router"}
|
||||
}
|
||||
assert State._var_dependencies == {
|
||||
"router": {(RouterVarDepState.get_full_name(), "foo")}
|
||||
}
|
||||
assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[
|
||||
"router"
|
||||
]
|
||||
|
||||
rx_state = State()
|
||||
parent_state = RouterVarParentState()
|
||||
state = RouterVarDepState()
|
||||
|
||||
# link states
|
||||
rx_state.substates = {RouterVarParentState.get_name(): parent_state}
|
||||
parent_state.parent_state = rx_state
|
||||
state.parent_state = parent_state
|
||||
parent_state.substates = {RouterVarDepState.get_name(): state}
|
||||
|
||||
populated_substate_classes = (
|
||||
await rx_state._recursively_populate_dependent_substates()
|
||||
)
|
||||
assert populated_substate_classes == {State, RouterVarDepState}
|
||||
# Get state from state manager.
|
||||
state_manager.state = State
|
||||
rx_state = await state_manager.get_state(_substate_key(token, State))
|
||||
assert RouterVarParentState.get_name() in rx_state.substates
|
||||
parent_state = rx_state.substates[RouterVarParentState.get_name()]
|
||||
assert RouterVarDepState.get_name() in parent_state.substates
|
||||
state = parent_state.substates[RouterVarDepState.get_name()]
|
||||
|
||||
assert state.dirty_vars == set()
|
||||
|
||||
|
@ -18,7 +18,6 @@ from reflex.utils.exceptions import (
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import (
|
||||
AsyncComputedVar,
|
||||
ComputedVar,
|
||||
LiteralVar,
|
||||
Var,
|
||||
|
Loading…
Reference in New Issue
Block a user