From ca3c0fd723cc898c5bd902fdea70039e2b32c1b9 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Jan 2025 10:00:23 -0800 Subject: [PATCH] flatten StateManagerRedis.get_state algorithm simplify fetching of states and avoid repeatedly fetching the same state --- reflex/state.py | 392 +++++++++----------------------------- tests/units/test_state.py | 36 ++-- tests/units/test_var.py | 1 - 3 files changed, 111 insertions(+), 318 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 12c5534fb..1f162bf5d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1465,65 +1465,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } ) - @classmethod - def _get_common_ancestor(cls, other: Type[BaseState]) -> str: - """Find the name of the nearest common ancestor shared by this and the other state. - - Args: - other: The other state. - - Returns: - Full name of the nearest common ancestor. - """ - common_ancestor_parts = [] - for part1, part2 in zip( - cls.get_full_name().split("."), - other.get_full_name().split("."), - ): - if part1 != part2: - break - common_ancestor_parts.append(part1) - return ".".join(common_ancestor_parts) - - @classmethod - def _determine_missing_parent_states( - cls, target_state_cls: Type[BaseState] - ) -> tuple[str, list[str]]: - """Determine the missing parent states between the target_state_cls and common ancestor of this state. - - Args: - target_state_cls: The class of the state to find missing parent states for. - - Returns: - The name of the common ancestor and the list of missing parent states. - """ - common_ancestor_name = cls._get_common_ancestor(target_state_cls) - common_ancestor_parts = common_ancestor_name.split(".") - target_state_parts = tuple(target_state_cls.get_full_name().split(".")) - relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :] - - # Determine which parent states to fetch from the common ancestor down to the target_state_cls. - fetch_parent_states = [common_ancestor_name] - for relative_parent_state_name in relative_target_state_parts: - fetch_parent_states.append( - ".".join((fetch_parent_states[-1], relative_parent_state_name)) - ) - - return common_ancestor_name, fetch_parent_states[1:-1] - - def _get_parent_states(self) -> list[tuple[str, BaseState]]: - """Get all parent state instances up to the root of the state tree. - - Returns: - A list of tuples containing the name and the instance of each parent state. - """ - parent_states_with_name = [] - parent_state = self - while parent_state.parent_state is not None: - parent_state = parent_state.parent_state - parent_states_with_name.append((parent_state.get_full_name(), parent_state)) - return parent_states_with_name - def _get_root_state(self) -> BaseState: """Get the root state of the state tree. @@ -1555,9 +1496,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) - state_in_redis = await state_manager._link_arbitrary_state( - self, - state_cls, + state_in_redis = await state_manager.get_state( + token=_substate_key(self.router.session.client_token, state_cls), + top_level=False, + for_state_instance=self, ) if not isinstance(state_in_redis, state_cls): @@ -1944,54 +1886,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if include_backend or not self.computed_vars[cvar]._backend } - async def _recursively_populate_dependent_substates( - self, - seen_classes: set[type[BaseState]] | None = None, - ) -> set[type[BaseState]]: - """Fetch all substates that have computed var dependencies on this state. - - Args: - seen_classes: set of classes that have already been seen to prevent infinite recursion. - - Returns: - The set of classes that were processed (mostly for testability). - """ - if seen_classes is None: - print( - f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:" - ) - seen_classes = set() - if type(self) in seen_classes: - return seen_classes - seen_classes.add(type(self)) - populated_substate_instances = {} - for substate_cls in { - self.get_class_substate((self.get_name(), *substate_name.split("."))) - for substate_name in self._always_dirty_substates - }: - # _always_dirty_substates need to be fetched to recalc computed vars. - if substate_cls not in populated_substate_instances: - print(f"fetching always dirty {substate_cls}") - populated_substate_instances[substate_cls] = await self.get_state( - substate_cls - ) - for dep_set in self._var_dependencies.values(): - for substate_name, _ in dep_set: - if substate_name == self.get_full_name(): - # Do NOT fetch our own state instance. - continue - substate_cls = self.get_root_state().get_class_substate(substate_name) - if substate_cls not in populated_substate_instances: - print(f"fetching dependent {substate_cls}") - populated_substate_instances[substate_cls] = await self.get_state( - substate_cls - ) - for substate in populated_substate_instances.values(): - await substate._recursively_populate_dependent_substates( - seen_classes=seen_classes, - ) - return seen_classes - def get_delta(self) -> Delta: """Get the delta for the state. @@ -3316,179 +3210,74 @@ class StateManagerRedis(StateManager): b"evicted", } - async def _get_parent_state( - self, token: str, state: BaseState | None = None - ) -> BaseState | None: - """Get the parent state for the state requested in the token. - - Args: - token: The token to get the state for (_substate_key). - state: The state instance to get parent state for. - - Returns: - The parent state for the state requested by the token or None if there is no such parent. - """ - parent_state = None - client_token, state_path = _split_substate_key(token) - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - cached_substates = None - if state is not None: - cached_substates = [state] - # Retrieve the parent state to populate event handlers onto this substate. - parent_state = await self.get_state( - token=_substate_key(client_token, parent_state_name), - top_level=False, - get_substates=False, - cached_substates=cached_substates, - ) - return parent_state - - async def _populate_parent_states( - self, calling_state: BaseState, target_state_cls: Type[BaseState] - ): - """Populate substates in the tree between the target_state_cls and common ancestor of calling_state. - - Args: - calling_state: The substate instance requesting subtree population. - target_state_cls: The class of the state to populate parent states for. - - Returns: - The parent state instance of target_state_cls. - """ - # Find the missing parent states up to the common ancestor. - ( - common_ancestor_name, - missing_parent_states, - ) = calling_state._determine_missing_parent_states(target_state_cls) - - # Fetch all missing parent states and link them up to the common ancestor. - parent_states_tuple = calling_state._get_parent_states() - root_state = parent_states_tuple[-1][1] - parent_states_by_name = dict(parent_states_tuple) - parent_state = parent_states_by_name[common_ancestor_name] - for parent_state_name in missing_parent_states: - try: - parent_state = root_state.get_substate(parent_state_name.split(".")) - # The requested state is already cached, do NOT fetch it again. - continue - except ValueError: - # The requested state is missing, fetch from redis. - pass - parent_state = await self.get_state( - token=_substate_key( - calling_state.router.session.client_token, parent_state_name - ), - top_level=False, - get_substates=False, - parent_state=parent_state, - ) - - # Return the direct parent of target_state_cls for subsequent linking. - return parent_state - - async def _link_arbitrary_state( - self, calling_state: BaseState, state_cls: Type[T_STATE] - ) -> T_STATE: - """Get a state instance from redis. - - Args: - calling_state: The state instance requesting the newly linked instance of state_cls. - state_cls: The class of the state to link into the tree. - - Returns: - The instance of state_cls associated with calling_state's client_token. - - Raises: - StateMismatchError: If the state instance is not of the expected type. - """ - # Fetch all missing parent states from redis. - parent_state_of_state_cls = await self._populate_parent_states( - calling_state, state_cls - ) - - # Then get the target state and all its substates. - state_in_redis = await self.get_state( - token=_substate_key(calling_state.router.session.client_token, state_cls), - top_level=False, - get_substates=True, - parent_state=parent_state_of_state_cls, - ) - - return state_in_redis - - async def _populate_substates( + def _get_required_state_classes( self, - token: str, - state: BaseState, - all_substates: bool = False, - ): - """Fetch and link substates for the given state instance. - - There is no return value; the side-effect is that `state` will have `substates` populated, - and each substate will have its `parent_state` set to `state`. - - Args: - token: The token to get the state for. - state: The state instance to populate substates for. - all_substates: Whether to fetch all substates or just required substates. - """ - client_token, _ = _split_substate_key(token) - - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._get_potentially_dirty_states() - if all_substates: - # All substates are requested. - fetch_substates.update(state.get_substates()) - - tasks = {} - link_tasks = set() - # Retrieve the necessary substates from redis. - for substate_cls in fetch_substates: - if substate_cls.get_name() in state.substates: - continue - substate_name = substate_cls.get_name() - if substate_cls in state.get_substates(): - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, - ) + target_state_cls: Type[BaseState], + subclasses: bool = False, + required_state_classes: set[Type[BaseState]] | None = None, + ) -> set[Type[BaseState]]: + if required_state_classes is None: + required_state_classes = set() + # Get the substates if requested. + if subclasses: + for substate in target_state_cls.get_substates(): + self._get_required_state_classes( + substate, + subclasses=True, + required_state_classes=required_state_classes, ) - else: - try: - state._get_root_state().get_substate(substate_name.split(".")) - except ValueError: - # The requested state is missing, so fetch and link it (and its parents). - link_tasks.add( - asyncio.create_task( - self._link_arbitrary_state(state, substate_cls) - ) - ) + if target_state_cls in required_state_classes: + return required_state_classes + required_state_classes.add(target_state_cls) - for substate_name, substate_task in tasks.items(): - state.substates[substate_name] = await substate_task - await asyncio.gather(*link_tasks) + # Get dependent substates. + for pd_substates in target_state_cls._get_potentially_dirty_states(): + self._get_required_state_classes( + pd_substates, + subclasses=False, + required_state_classes=required_state_classes, + ) + + # Get the parent state if it exists. + if parent_state := target_state_cls.get_parent_state(): + self._get_required_state_classes( + parent_state, + subclasses=False, + required_state_classes=required_state_classes, + ) + return required_state_classes + + def _get_populated_states( + self, + target_state: BaseState, + populated_states: dict[str, BaseState] | None = None, + ) -> dict[str, BaseState]: + if populated_states is None: + populated_states = {} + if target_state.get_full_name() in populated_states: + return populated_states + populated_states[target_state.get_full_name()] = target_state + for substate in target_state.substates.values(): + self._get_populated_states(substate, populated_states=populated_states) + if target_state.parent_state is not None: + self._get_populated_states( + target_state.parent_state, populated_states=populated_states + ) + return populated_states @override async def get_state( self, token: str, top_level: bool = True, - get_substates: bool = True, - parent_state: BaseState | None = None, - cached_substates: list[BaseState] | None = None, + for_state_instance: BaseState | None = None, ) -> BaseState: """Get the state for a token. Args: token: The token to get the state for. top_level: If true, return an instance of the top-level state (self.state). - get_substates: If true, also retrieve substates. - parent_state: If provided, use this parent_state instead of getting it from redis. - cached_substates: If provided, attach these substates to the state. + for_state_instance: If provided, attach the requested states to this existing state tree. Returns: The state for the token. @@ -3497,7 +3286,7 @@ class StateManagerRedis(StateManager): RuntimeError: when the state_cls is not specified in the token """ # Split the actual token from the fully qualified substate name. - _, state_path = _split_substate_key(token) + token, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. state_cls = self.state.get_class_substate(state_path) @@ -3506,37 +3295,44 @@ class StateManagerRedis(StateManager): f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" ) - # The deserialized or newly created (sub)state instance. - state = None + # Determine which states we already have. + flat_state_tree: dict[str, BaseState] = ( + self._get_populated_states(for_state_instance) if for_state_instance else {} + ) - # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + # Determine which states from the tree need to be fetched. + required_state_classes = self._get_required_state_classes( + state_cls, subclasses=True + ) - {type(s) for s in flat_state_tree.values()} - if redis_state is not None: - # Deserialize the substate. - with contextlib.suppress(StateSchemaMismatchError): - state = BaseState._deserialize(data=redis_state) - if state is None: - # Key didn't exist or schema mismatch so create a new instance for this token. - state = state_cls( - init_substates=False, - _reflex_internal_init=True, - ) - # Populate parent state if missing and requested. - if parent_state is None: - parent_state = await self._get_parent_state(token, state) - # Set up Bidirectional linkage between this state and its parent. - if parent_state is not None: - parent_state.substates[state.get_name()] = state - state.parent_state = parent_state - # Avoid fetching substates multiple times. - if cached_substates: - for substate in cached_substates: - state.substates[substate.get_name()] = substate - if substate.parent_state is None: - substate.parent_state = state - # Populate substates if requested. - await self._populate_substates(token, state, all_substates=get_substates) + for state_cls in sorted( + required_state_classes, key=lambda x: x.get_full_name() + ): + state = None + redis_state = await self.redis.get(_substate_key(token, state_cls)) + + if redis_state is not None: + # Deserialize the substate. + with contextlib.suppress(StateSchemaMismatchError): + state = BaseState._deserialize(data=redis_state) + if state is None: + # Key didn't exist or schema mismatch so create a new instance for this token. + state = state_cls( + init_substates=False, + _reflex_internal_init=True, + ) + flat_state_tree[state.get_full_name()] = state + if state.get_parent_state() is not None: + parent_state_name, _dot, state_name = state.get_full_name().rpartition( + "." + ) + parent_state = flat_state_tree.get(parent_state_name) + if parent_state is None: + raise Exception( + f"Parent state should get fetched first... got {state.get_full_name()} instead" + ) + parent_state.substates[state_name] = state + state.parent_state = parent_state # To retain compatibility with previous implementation, by default, we return # the top-level state by chasing `parent_state` pointers up the tree. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 6086799e1..0d9d438ea 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3212,8 +3212,13 @@ def test_potentially_dirty_substates(): @pytest.mark.asyncio -async def test_router_var_dep() -> None: - """Test that router var dependencies are correctly tracked.""" +async def test_router_var_dep(state_manager: StateManager, token: str) -> None: + """Test that router var dependencies are correctly tracked. + + Args: + state_manager: A state manager. + token: A token. + """ class RouterVarParentState(State): """A parent state for testing router var dependency.""" @@ -3233,24 +3238,17 @@ async def test_router_var_dep() -> None: assert foo._deps(objclass=RouterVarDepState) == { RouterVarDepState.get_full_name(): {"router"} } - assert State._var_dependencies == { - "router": {(RouterVarDepState.get_full_name(), "foo")} - } + assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[ + "router" + ] - rx_state = State() - parent_state = RouterVarParentState() - state = RouterVarDepState() - - # link states - rx_state.substates = {RouterVarParentState.get_name(): parent_state} - parent_state.parent_state = rx_state - state.parent_state = parent_state - parent_state.substates = {RouterVarDepState.get_name(): state} - - populated_substate_classes = ( - await rx_state._recursively_populate_dependent_substates() - ) - assert populated_substate_classes == {State, RouterVarDepState} + # Get state from state manager. + state_manager.state = State + rx_state = await state_manager.get_state(_substate_key(token, State)) + assert RouterVarParentState.get_name() in rx_state.substates + parent_state = rx_state.substates[RouterVarParentState.get_name()] + assert RouterVarDepState.get_name() in parent_state.substates + state = parent_state.substates[RouterVarDepState.get_name()] assert state.dirty_vars == set() diff --git a/tests/units/test_var.py b/tests/units/test_var.py index ab396b15e..30fbd4e9b 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -18,7 +18,6 @@ from reflex.utils.exceptions import ( from reflex.utils.imports import ImportVar from reflex.vars import VarData from reflex.vars.base import ( - AsyncComputedVar, ComputedVar, LiteralVar, Var,