From 68de3f41c465c845b4c27b98559e6e74c1901efe Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 23 Oct 2024 16:09:02 -0700 Subject: [PATCH] [ENG-3989] Ensure non-serialized states are present in StateManagerDisk (#4230) If a non-root state was serialized, but its substates were not, then these would not be populated when reloading the pickled state, because only substates from the root were being populated with fresh versions. Now, capture the substates from all fresh states and apply them to the deserialized state for each substate to ensure that the entire state tree has all substates instantiated after deserializing, even substates that were never serialized originally. --- reflex/state.py | 8 ++++++-- tests/units/test_state.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 9740bddaa..799fc1eed 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2842,9 +2842,13 @@ class StateManagerDisk(StateManager): for substate in state.get_substates(): substate_token = _substate_key(client_token, substate) + fresh_instance = await root_state.get_state(substate) instance = await self.load_state(substate_token) - if instance is None: - instance = await root_state.get_state(substate) + if instance is not None: + # Ensure all substates exist, even if they weren't serialized previously. + instance.substates = fresh_instance.substates + else: + instance = fresh_instance state.substates[substate.get_name()] = instance instance.parent_state = state diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 610d69110..ebfeeb72c 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3313,3 +3313,36 @@ def test_assignment_to_undeclared_vars(): state.handle_supported_regular_vars() state.handle_non_var() + + +@pytest.mark.asyncio +async def test_deserialize_gc_state_disk(token): + """Test that a state can be deserialized from disk with a grandchild state. + + Args: + token: A token. + """ + + class Root(BaseState): + pass + + class State(Root): + num: int = 42 + + class Child(State): + foo: str = "bar" + + dsm = StateManagerDisk(state=Root) + async with dsm.modify_state(token) as root: + s = await root.get_state(State) + s.num += 1 + c = await root.get_state(Child) + assert s._get_was_touched() + assert not c._get_was_touched() + + dsm2 = StateManagerDisk(state=Root) + root = await dsm2.get_state(token) + s = await root.get_state(State) + assert s.num == 43 + c = await root.get_state(Child) + assert c.foo == "bar"