From 20d8192360752323ae60fefc953f72f640d62e5e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 3 Oct 2024 19:16:01 -0700 Subject: [PATCH] Handle root_state deserialized from disk In this case, we need to initialize the whole state tree, so any non-persistent states will still get default values, whereas on-disk states will overwrite the defaults. --- reflex/state.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 5bf7749aa..4a7076f6e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2743,11 +2743,10 @@ class StateManagerDisk(StateManager): substate_token = _substate_key(client_token, substate) instance = await self.load_state(substate_token) - if instance is not None: - state.substates[substate.get_name()] = instance - substate.parent_state = state - else: + if instance is None: instance = await root_state.get_state(substate) + state.substates[substate.get_name()] = instance + instance.parent_state = state await self.populate_substates(client_token, instance, root_state) @@ -2765,12 +2764,23 @@ class StateManagerDisk(StateManager): The state for the token. """ client_token = _split_substate_key(token)[0] - root_state = self.states.get(client_token) + root_state_token = _substate_key(client_token, self.state) + root_state = self.states.get(root_state_token) + if root_state is not None: + # Retrieved state from memory. + return root_state + + # Deserialize root state from disk. + root_state = await self.load_state(root_state_token) + # Create a new root state tree with all substates instantiated. + fresh_root_state = self.state(_reflex_internal_init=True) if root_state is None: - # Create a new root state which will be persisted in the next set_state call. - root_state = self.state(_reflex_internal_init=True) - self.states[client_token] = root_state - await self.populate_substates(client_token, root_state, root_state) + root_state = fresh_root_state + else: + # Ensure all substates exist, even if they were not serialized previously. + root_state.substates = fresh_root_state.substates + self.states[root_state_token] = root_state + await self.populate_substates(client_token, root_state, root_state) return root_state async def set_state_for_substate(self, client_token: str, substate: BaseState):