From 82c82d9bd901bda89d8a5892bb7b0e1f037f39c3 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 27 Nov 2024 02:10:04 +0100 Subject: [PATCH] wip --- reflex/state.py | 288 ++++++++++++----------- tests/integration/test_client_storage.py | 18 +- tests/units/test_state_tree.py | 10 +- 3 files changed, 165 insertions(+), 151 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 55f29cf45..350cd7067 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -938,7 +938,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for substate in cls.get_substates(): if path[0] == substate.get_name(): return substate.get_class_substate(path[1:]) - raise ValueError(f"Invalid path: {path}") + raise ValueError(f"Invalid path: {cls.get_full_name()=} {path=}") + + @classmethod + # @functools.lru_cache() + def get_all_substate_classes(cls) -> set[Type[BaseState]]: + """Get all substate classes of the state. + + Returns: + The set of all substate classes. + """ + substates = set(cls.get_substates()) + for substate in cls.get_substates(): + substates.update(substate.get_all_substate_classes()) + return substates @classmethod def get_class_var(cls, path: Sequence[str]) -> Any: @@ -1393,7 +1406,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return self path = path[1:] if path[0] not in self.substates: - raise ValueError(f"Invalid path: {path}") + raise ValueError( + f"Invalid path: {path=} {self.get_full_name()=} {self.substates.keys()=}" + ) return self.substates[path[0]].get_substate(path[1:]) @classmethod @@ -1455,6 +1470,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): parent_states_with_name.append((parent_state.get_full_name(), parent_state)) return parent_states_with_name + def _get_all_loaded_states(self) -> dict[str, BaseState]: + """Get all loaded states in the state tree. + + Returns: + A dict mapping full state names to all loaded states in the state tree. 
+ """ + root_state = self._get_root_state() + d = {root_state.get_full_name(): root_state} + d.update(root_state._get_loaded_substates()) + return d + + def _get_loaded_substates(self) -> dict[str, BaseState]: + """Get all loaded substates of this state. + + Returns: + A dict mapping full state names to all loaded substates of this state. + """ + loaded_substates = {} + for substate in self.substates.values(): + loaded_substates[substate.get_full_name()] = substate + loaded_substates.update(substate._get_loaded_substates()) + return loaded_substates + def _get_root_state(self) -> BaseState: """Get the root state of the state tree. @@ -1861,6 +1899,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if include_backend or not self.computed_vars[cvar]._backend ) + # TODO: just return full name? cache? @classmethod def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: """Determine substates which could be affected by dirty vars in this state. @@ -1882,6 +1921,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) return fetch_substates + # TODO: just return full name? cache? + # this only needs to be computed once, and only for the root state? + @classmethod + def _recursive_potentially_dirty_substates(cls) -> set[Type[BaseState]]: + """Recursively determine substates which could be affected by dirty vars in this state. + + Returns: + Set of State classes that may need to be fetched to recalc computed vars. + """ + fetch_substates = cls._potentially_dirty_substates() + for substate_cls in cls.get_substates(): + fetch_substates.update( + substate_cls._recursive_potentially_dirty_substates() + ) + return fetch_substates + def get_delta(self) -> Delta: """Get the delta for the state. @@ -3190,77 +3245,6 @@ class StateManagerRedis(StateManager): b"evicted", } - async def _get_parent_state( - self, token: str, state: BaseState | None = None - ) -> BaseState | None: - """Get the parent state for the state requested in the token. 
- - Args: - token: The token to get the state for (_substate_key). - state: The state instance to get parent state for. - - Returns: - The parent state for the state requested by the token or None if there is no such parent. - """ - parent_state = None - client_token, state_path = _split_substate_key(token) - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - cached_substates = None - if state is not None: - cached_substates = [state] - # Retrieve the parent state to populate event handlers onto this substate. - parent_state = await self.get_state( - token=_substate_key(client_token, parent_state_name), - top_level=False, - get_substates=False, - cached_substates=cached_substates, - ) - return parent_state - - async def _populate_substates( - self, - token: str, - state: BaseState, - all_substates: bool = False, - ): - """Fetch and link substates for the given state instance. - - There is no return value; the side-effect is that `state` will have `substates` populated, - and each substate will have its `parent_state` set to `state`. - - Args: - token: The token to get the state for. - state: The state instance to populate substates for. - all_substates: Whether to fetch all substates or just required substates. - """ - client_token, _ = _split_substate_key(token) - - if all_substates: - # All substates are requested. - fetch_substates = state.get_substates() - else: - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._potentially_dirty_substates() - - tasks = {} - # Retrieve the necessary substates from redis. 
- for substate_cls in fetch_substates: - if substate_cls.get_name() in state.substates: - continue - substate_name = substate_cls.get_name() - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, - ) - ) - - for substate_name, substate_task in tasks.items(): - state.substates[substate_name] = await substate_task - @override async def get_state( self, @@ -3268,7 +3252,6 @@ class StateManagerRedis(StateManager): top_level: bool = True, get_substates: bool = True, parent_state: BaseState | None = None, - cached_substates: list[BaseState] | None = None, ) -> BaseState: """Get the state for a token. @@ -3277,7 +3260,6 @@ class StateManagerRedis(StateManager): top_level: If true, return an instance of the top-level state (self.state). get_substates: If true, also retrieve substates. parent_state: If provided, use this parent_state instead of getting it from redis. - cached_substates: If provided, attach these substates to the state. Returns: The state for the token. @@ -3285,8 +3267,8 @@ class StateManagerRedis(StateManager): Raises: RuntimeError: when the state_cls is not specified in the token """ - # Split the actual token from the fully qualified substate name. - _, state_path = _split_substate_key(token) + # new impl from top to bottom + client_token, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. state_cls = self.state.get_class_substate(state_path) @@ -3295,44 +3277,92 @@ class StateManagerRedis(StateManager): "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" ) - # The deserialized or newly created (sub)state instance. - state = None + state_tokens = {state_path} - # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + # walk up the state path + walk_state_path = state_path + while "." 
in walk_state_path: + walk_state_path = walk_state_path.rpartition(".")[0] + state_tokens.add(walk_state_path) - if redis_state is not None: - # Deserialize the substate. - with contextlib.suppress(StateSchemaMismatchError): - state = BaseState._deserialize(data=redis_state) - if state is None: - # Key didn't exist or schema mismatch so create a new instance for this token. - state = state_cls( - init_substates=False, - _reflex_internal_init=True, + state_tokens.update( + { + substate.get_full_name() + for substate in self.state._recursive_potentially_dirty_substates() + } + ) + if get_substates: + state_tokens.update( + { + substate.get_full_name() + for substate in state_cls.get_all_substate_classes() + } ) - # Populate parent state if missing and requested. - if parent_state is None: - parent_state = await self._get_parent_state(token, state) - # Set up Bidirectional linkage between this state and its parent. - if parent_state is not None: - parent_state.substates[state.get_name()] = state - state.parent_state = parent_state - # Avoid fetching substates multiple times. - if cached_substates: - for substate in cached_substates: - state.substates[substate.get_name()] = substate - if substate.parent_state is None: - substate.parent_state = state - # Populate substates if requested. - await self._populate_substates(token, state, all_substates=get_substates) - # To retain compatibility with previous implementation, by default, we return - # the top-level state by chasing `parent_state` pointers up the tree. 
+ loaded_states = {} + if parent_state is not None: + loaded_states = parent_state._get_all_loaded_states() + # remove all states that are already loaded + state_tokens = state_tokens.difference(loaded_states.keys()) + + redis_states = await self.hmget(name=client_token, keys=list(state_tokens)) + redis_states.update(loaded_states) + root_state = redis_states[self.state.get_full_name()] + self.recursive_link_substates(state=root_state, substates=redis_states) + if top_level: - return state._get_root_state() + return root_state + + state = redis_states[state_path] return state + def recursive_link_substates( + self, + state: BaseState, + substates: dict[str, BaseState], + ): + """Recursively link substates to a state. + + Args: + state: The state to link substates to. + substates: The substates to link. + """ + for substate_cls in state.get_substates(): + if substate_cls.get_full_name() not in substates: + continue + substate = substates[substate_cls.get_full_name()] + state.substates[substate.get_name()] = substate + substate.parent_state = state + self.recursive_link_substates( + state=substate, + substates=substates, + ) + + async def hmget(self, name: str, keys: List[str]) -> dict[str, BaseState]: + """Get multiple states from a redis hash, creating new instances for missing or unreadable ones. + + Args: + name: The name of the hash. + keys: The hash fields to get (state paths passed to get_class_substate); not modified. + + Returns: + A dict mapping each state's full name to the deserialized or newly created state. + """ + d = {} + for key, raw in zip(keys, await self.redis.hmget(name=name, keys=keys)): # type: ignore + state = None + if raw is not None: + with contextlib.suppress(StateSchemaMismatchError): + state = BaseState._deserialize(data=raw) + if state is None: + state_cls = self.state.get_class_substate(key) + state = state_cls( + init_substates=False, + _reflex_internal_init=True, + ) + d[state.get_full_name()] = state + return d + @override async def set_state( self, @@ -3368,31 +3398,25 @@ class StateManagerRedis(StateManager): f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." 
) - # Recursively set_state on all known substates. - tasks = [] - for substate in state.substates.values(): - tasks.append( - asyncio.create_task( - self.set_state( - token=_substate_key(client_token, substate), - state=substate, - lock_id=lock_id, - ) - ) - ) - # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). - if state._get_was_touched(): - pickle_state = state._serialize() - if pickle_state: - await self.redis.set( - _substate_key(client_token, state), - pickle_state, - ex=self.token_expiration, - ) + redis_hashset = {} - # Wait for substates to be persisted. - for t in tasks: - await t + for state_name, substate in state._get_all_loaded_states().items(): + if not substate._get_was_touched(): + continue + pickle_state = substate._serialize() + if not pickle_state: + continue + redis_hashset[state_name] = pickle_state + + if not redis_hashset: + return + + await self.redis.hmset(name=client_token, mapping=redis_hashset) # type: ignore + await self.redis.hexpire( + client_token, + self.token_expiration, + *redis_hashset.keys(), + ) @override @contextlib.asynccontextmanager diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 236d3e14e..649b236a4 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.by import By from selenium.webdriver.remote.webdriver import WebDriver from reflex.state import ( - State, StateManagerDisk, StateManagerMemory, StateManagerRedis, @@ -278,6 +277,7 @@ async def test_client_side_state( set_sub_sub_state_button.click() token = poll_for_token() + assert token is not None # get a reference to all cookie and local storage elements c1 = driver.find_element(By.ID, "c1") @@ -613,16 +613,7 @@ async def test_client_side_state( # Simulate state expiration if isinstance(client_side.state_manager, StateManagerRedis): - await 
client_side.state_manager.redis.delete( - _substate_key(token, State.get_full_name()) - ) - await client_side.state_manager.redis.delete(_substate_key(token, state_name)) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_state_name) - ) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_sub_state_name) - ) + await client_side.state_manager.redis.delete(token) elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)): del client_side.state_manager.states[token] if isinstance(client_side.state_manager, StateManagerDisk): @@ -679,9 +670,8 @@ async def test_client_side_state( # Get the backend state and ensure the values are still set async def get_sub_state(): - root_state = await client_side.get_state( - _substate_key(token or "", sub_state_name) - ) + assert token is not None + root_state = await client_side.get_state(_substate_key(token, sub_state_name)) state = root_state.substates[client_side.get_state_name("_client_side_state")] sub_state = state.substates[ client_side.get_state_name("_client_side_sub_state") diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index ebdd877de..f1d100ff2 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -354,11 +354,11 @@ async def state_manager_redis( ], ) async def test_get_state_tree( - state_manager_redis, - token, - substate_cls, - exp_root_substates, - exp_root_dict_keys, + state_manager_redis: StateManagerRedis, + token: str, + substate_cls: type[BaseState], + exp_root_substates: list[str], + exp_root_dict_keys: list[str], ): """Test getting state trees and assert on which branches are retrieved.