Handle root_state deserialized from disk

In this case, we need to initialize the whole state tree, so any non-persistent
states will still get default values, whereas on-disk states will overwrite the
defaults.
This commit is contained in:
Masen Furer 2024-10-03 19:16:01 -07:00
parent ae24d72201
commit 20d8192360
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -2743,11 +2743,10 @@ class StateManagerDisk(StateManager):
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, substate)
instance = await self.load_state(substate_token) instance = await self.load_state(substate_token)
if instance is not None: if instance is None:
state.substates[substate.get_name()] = instance
substate.parent_state = state
else:
instance = await root_state.get_state(substate) instance = await root_state.get_state(substate)
state.substates[substate.get_name()] = instance
instance.parent_state = state
await self.populate_substates(client_token, instance, root_state) await self.populate_substates(client_token, instance, root_state)
@ -2765,12 +2764,23 @@ class StateManagerDisk(StateManager):
The state for the token. The state for the token.
""" """
client_token = _split_substate_key(token)[0] client_token = _split_substate_key(token)[0]
root_state = self.states.get(client_token) root_state_token = _substate_key(client_token, self.state)
root_state = self.states.get(root_state_token)
if root_state is not None:
# Retrieved state from memory.
return root_state
# Deserialize root state from disk.
root_state = await self.load_state(root_state_token)
# Create a new root state tree with all substates instantiated.
fresh_root_state = self.state(_reflex_internal_init=True)
if root_state is None: if root_state is None:
# Create a new root state which will be persisted in the next set_state call. root_state = fresh_root_state
root_state = self.state(_reflex_internal_init=True) else:
self.states[client_token] = root_state # Ensure all substates exist, even if they were not serialized previously.
await self.populate_substates(client_token, root_state, root_state) root_state.substates = fresh_root_state.substates
self.states[root_state_token] = root_state
await self.populate_substates(client_token, root_state, root_state)
return root_state return root_state
async def set_state_for_substate(self, client_token: str, substate: BaseState): async def set_state_for_substate(self, client_token: str, substate: BaseState):