diff --git a/reflex/state.py b/reflex/state.py index 5798564fa..5bf7749aa 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2711,34 +2711,24 @@ class StateManagerDisk(StateManager): self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" ).absolute() - async def load_state(self, token: str, root_state: BaseState) -> BaseState: + async def load_state(self, token: str) -> BaseState | None: """Load a state object based on the provided token. Args: token: The token used to identify the state object. - root_state: The root state object. Returns: - The loaded state object. + The loaded state object or None. """ - if token in self.states: - return self.states[token] - - client_token, substate_address = _split_substate_key(token) - token_path = self.token_path(token) if token_path.exists(): try: with token_path.open(mode="rb") as file: - substate = BaseState._deserialize(fp=file) - await self.populate_substates(client_token, substate, root_state) - return substate + return BaseState._deserialize(fp=file) except Exception: pass - return root_state.get_substate(substate_address.split(".")[1:]) - async def populate_substates( self, client_token: str, state: BaseState, root_state: BaseState ): @@ -2752,10 +2742,14 @@ class StateManagerDisk(StateManager): for substate in state.get_substates(): substate_token = _substate_key(client_token, substate) - substate = await self.load_state(substate_token, root_state) + instance = await self.load_state(substate_token) + if instance is not None: + state.substates[substate.get_name()] = instance + substate.parent_state = state + else: + instance = await root_state.get_state(substate) - state.substates[substate.get_name()] = substate - substate.parent_state = state + await self.populate_substates(client_token, instance, root_state) @override async def get_state( @@ -2770,15 +2764,14 @@ class StateManagerDisk(StateManager): Returns: The state for the token. """ - client_token, substate_address = _split_substate_key(token) - - root_state_token = _substate_key(client_token, substate_address.split(".")[0]) - root_state = self.states.get(root_state_token) + client_token = _split_substate_key(token)[0] + root_state = self.states.get(client_token) if root_state is None: # Create a new root state which will be persisted in the next set_state call. root_state = self.state(_reflex_internal_init=True) - - return await self.load_state(root_state_token, root_state) + self.states[client_token] = root_state + await self.populate_substates(client_token, root_state, root_state) + return root_state async def set_state_for_substate(self, client_token: str, substate: BaseState): """Set the state for a substate. @@ -2789,12 +2782,12 @@ class StateManagerDisk(StateManager): """ substate_token = _substate_key(client_token, substate) - self.states[substate_token] = substate - - state_dilled = substate._serialize() - if not self.states_directory.exists(): - self.states_directory.mkdir(parents=True, exist_ok=True) - self.token_path(substate_token).write_bytes(state_dilled) + if substate._get_was_touched(): + substate._was_touched = False # Reset the touched flag after serializing. + pickle_state = substate._serialize() + if not self.states_directory.exists(): + self.states_directory.mkdir(parents=True, exist_ok=True) + self.token_path(substate_token).write_bytes(pickle_state) for substate_substate in substate.substates.values(): await self.set_state_for_substate(client_token, substate_substate)