Consolidate logic in StateManagerRedis.get_state

This commit is contained in:
Masen Furer 2024-10-03 18:32:15 -07:00
parent 6dba120e6d
commit 42f8fcf8e9
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -2985,45 +2985,29 @@ class StateManagerRedis(StateManager):
if redis_state is not None:
# Deserialize the substate.
state = BaseState._deserialize(data=redis_state)
# Populate parent state if missing and requested.
if parent_state is None:
parent_state = await self._get_parent_state(token, state)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Avoid fetching substates multiple times.
if cached_substates:
for substate in cached_substates:
state.substates[substate.get_name()] = substate
if substate.parent_state is None:
substate.parent_state = state
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
return state._get_root_state()
return state
# TODO: dedupe the following logic with the above block
# Key didn't exist so we have to create a new instance for this token.
else:
# Key didn't exist so we have to create a new instance for this token.
# Instantiate the new state class (but don't persist it yet).
state = state_cls(
init_substates=False,
_reflex_internal_init=True,
)
# Populate parent state if missing and requested.
if parent_state is None:
parent_state = await self._get_parent_state(token)
# Instantiate the new state class (but don't persist it yet).
state = state_cls(
parent_state=parent_state,
init_substates=False,
_reflex_internal_init=True,
)
parent_state = await self._get_parent_state(token, state)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Populate substates for the newly created state.
# Avoid fetching substates multiple times.
if cached_substates:
for substate in cached_substates:
state.substates[substate.get_name()] = substate
if substate.parent_state is None:
substate.parent_state = state
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level: