Avoid fetching substates multiple times

In the presence of computed vars, substates may be cached more than once.
This commit is contained in:
Masen Furer 2024-03-20 02:43:59 -07:00
parent d77b900bd7
commit 6dba120e6d
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -2874,11 +2874,14 @@ class StateManagerRedis(StateManager):
# Only warn about each state class size once. # Only warn about each state class size once.
_warned_about_state_size: ClassVar[Set[str]] = set() _warned_about_state_size: ClassVar[Set[str]] = set()
async def _get_parent_state(self, token: str) -> BaseState | None: async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
"""Get the parent state for the state requested in the token. """Get the parent state for the state requested in the token.
Args: Args:
token: The token to get the state for (_substate_key). token: The token to get the state for (_substate_key).
state: The state instance to get parent state for.
Returns: Returns:
The parent state for the state requested by the token or None if there is no such parent. The parent state for the state requested by the token or None if there is no such parent.
@ -2887,11 +2890,15 @@ class StateManagerRedis(StateManager):
client_token, state_path = _split_substate_key(token) client_token, state_path = _split_substate_key(token)
parent_state_name = state_path.rpartition(".")[0] parent_state_name = state_path.rpartition(".")[0]
if parent_state_name: if parent_state_name:
cached_substates = None
if state is not None:
cached_substates = [state]
# Retrieve the parent state to populate event handlers onto this substate. # Retrieve the parent state to populate event handlers onto this substate.
parent_state = await self.get_state( parent_state = await self.get_state(
token=_substate_key(client_token, parent_state_name), token=_substate_key(client_token, parent_state_name),
top_level=False, top_level=False,
get_substates=False, get_substates=False,
cached_substates=cached_substates,
) )
return parent_state return parent_state
@ -2923,6 +2930,8 @@ class StateManagerRedis(StateManager):
tasks = {} tasks = {}
# Retrieve the necessary substates from redis. # Retrieve the necessary substates from redis.
for substate_cls in fetch_substates: for substate_cls in fetch_substates:
if substate_cls.get_name() in state.substates:
continue
substate_name = substate_cls.get_name() substate_name = substate_cls.get_name()
tasks[substate_name] = asyncio.create_task( tasks[substate_name] = asyncio.create_task(
self.get_state( self.get_state(
@ -2943,6 +2952,7 @@ class StateManagerRedis(StateManager):
top_level: bool = True, top_level: bool = True,
get_substates: bool = True, get_substates: bool = True,
parent_state: BaseState | None = None, parent_state: BaseState | None = None,
cached_substates: list[BaseState] | None = None,
) -> BaseState: ) -> BaseState:
"""Get the state for a token. """Get the state for a token.
@ -2951,6 +2961,7 @@ class StateManagerRedis(StateManager):
top_level: If true, return an instance of the top-level state (self.state). top_level: If true, return an instance of the top-level state (self.state).
get_substates: If true, also retrieve substates. get_substates: If true, also retrieve substates.
parent_state: If provided, use this parent_state instead of getting it from redis. parent_state: If provided, use this parent_state instead of getting it from redis.
cached_substates: If provided, attach these substates to the state.
Returns: Returns:
The state for the token. The state for the token.
@ -2977,11 +2988,17 @@ class StateManagerRedis(StateManager):
# Populate parent state if missing and requested. # Populate parent state if missing and requested.
if parent_state is None: if parent_state is None:
parent_state = await self._get_parent_state(token) parent_state = await self._get_parent_state(token, state)
# Set up Bidirectional linkage between this state and its parent. # Set up Bidirectional linkage between this state and its parent.
if parent_state is not None: if parent_state is not None:
parent_state.substates[state.get_name()] = state parent_state.substates[state.get_name()] = state
state.parent_state = parent_state state.parent_state = parent_state
# Avoid fetching substates multiple times.
if cached_substates:
for substate in cached_substates:
state.substates[substate.get_name()] = substate
if substate.parent_state is None:
substate.parent_state = state
# Populate substates if requested. # Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates) await self._populate_substates(token, state, all_substates=get_substates)