From 6dba120e6dfd877d5fd88597e472fa9f53acda29 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 20 Mar 2024 02:43:59 -0700 Subject: [PATCH] Avoid fetching substates multiple times In the presence of computed vars, substates may be cached more than once. --- reflex/state.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 5798564fa..64f5369bd 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2874,11 +2874,14 @@ class StateManagerRedis(StateManager): # Only warn about each state class size once. _warned_about_state_size: ClassVar[Set[str]] = set() - async def _get_parent_state(self, token: str) -> BaseState | None: + async def _get_parent_state( + self, token: str, state: BaseState | None = None + ) -> BaseState | None: """Get the parent state for the state requested in the token. Args: token: The token to get the state for (_substate_key). + state: The state instance to get parent state for. Returns: The parent state for the state requested by the token or None if there is no such parent. @@ -2887,11 +2890,15 @@ class StateManagerRedis(StateManager): client_token, state_path = _split_substate_key(token) parent_state_name = state_path.rpartition(".")[0] if parent_state_name: + cached_substates = None + if state is not None: + cached_substates = [state] # Retrieve the parent state to populate event handlers onto this substate. parent_state = await self.get_state( token=_substate_key(client_token, parent_state_name), top_level=False, get_substates=False, + cached_substates=cached_substates, ) return parent_state @@ -2923,6 +2930,8 @@ class StateManagerRedis(StateManager): tasks = {} # Retrieve the necessary substates from redis. for substate_cls in fetch_substates: + if substate_cls.get_name() in state.substates: + continue substate_name = substate_cls.get_name() tasks[substate_name] = asyncio.create_task( self.get_state( @@ -2943,6 +2952,7 @@ class StateManagerRedis(StateManager): top_level: bool = True, get_substates: bool = True, parent_state: BaseState | None = None, + cached_substates: list[BaseState] | None = None, ) -> BaseState: """Get the state for a token. @@ -2951,6 +2961,7 @@ class StateManagerRedis(StateManager): top_level: If true, return an instance of the top-level state (self.state). get_substates: If true, also retrieve substates. parent_state: If provided, use this parent_state instead of getting it from redis. + cached_substates: If provided, attach these substates to the state. Returns: The state for the token. @@ -2977,11 +2988,17 @@ class StateManagerRedis(StateManager): # Populate parent state if missing and requested. if parent_state is None: - parent_state = await self._get_parent_state(token) + parent_state = await self._get_parent_state(token, state) # Set up Bidirectional linkage between this state and its parent. if parent_state is not None: parent_state.substates[state.get_name()] = state state.parent_state = parent_state + # Avoid fetching substates multiple times. + if cached_substates: + for substate in cached_substates: + state.substates[substate.get_name()] = substate + if substate.parent_state is None: + substate.parent_state = state # Populate substates if requested. await self._populate_substates(token, state, all_substates=get_substates)