diff --git a/reflex/state.py b/reflex/state.py index 9740bddaa..799fc1eed 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2842,9 +2842,13 @@ class StateManagerDisk(StateManager): for substate in state.get_substates(): substate_token = _substate_key(client_token, substate) + fresh_instance = await root_state.get_state(substate) instance = await self.load_state(substate_token) - if instance is None: - instance = await root_state.get_state(substate) + if instance is not None: + # Ensure all substates exist, even if they weren't serialized previously. + instance.substates = fresh_instance.substates + else: + instance = fresh_instance state.substates[substate.get_name()] = instance instance.parent_state = state diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 610d69110..ebfeeb72c 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3313,3 +3313,36 @@ def test_assignment_to_undeclared_vars(): state.handle_supported_regular_vars() state.handle_non_var() + + +@pytest.mark.asyncio +async def test_deserialize_gc_state_disk(token): + """Test that a state can be deserialized from disk with a grandchild state. + + Args: + token: A token. + """ + + class Root(BaseState): + pass + + class State(Root): + num: int = 42 + + class Child(State): + foo: str = "bar" + + dsm = StateManagerDisk(state=Root) + async with dsm.modify_state(token) as root: + s = await root.get_state(State) + s.num += 1 + c = await root.get_state(Child) + assert s._get_was_touched() + assert not c._get_was_touched() + + dsm2 = StateManagerDisk(state=Root) + root = await dsm2.get_state(token) + s = await root.get_state(State) + assert s.num == 43 + c = await root.get_state(Child) + assert c.foo == "bar"