diff --git a/reflex/state.py b/reflex/state.py index d46e61439..7545aed54 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1232,9 +1232,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Determine which parent states to fetch from the common ancestor down to the target_state_cls. fetch_parent_states = [common_ancestor_name] - for ix, relative_parent_state_name in enumerate(relative_target_state_parts): + for relative_parent_state_name in relative_target_state_parts: fetch_parent_states.append( - ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name]) + ".".join((fetch_parent_states[-1], relative_parent_state_name)) ) return common_ancestor_name, fetch_parent_states[1:-1] @@ -1278,9 +1278,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) = self._determine_missing_parent_states(target_state_cls) # Fetch all missing parent states and link them up to the common ancestor. - parent_states_by_name = dict(self._get_parent_states()) + parent_states_tuple = self._get_parent_states() + root_state = parent_states_tuple[-1][1] + parent_states_by_name = dict(parent_states_tuple) parent_state = parent_states_by_name[common_ancestor_name] for parent_state_name in missing_parent_states: + try: + parent_state = root_state.get_substate(parent_state_name.split(".")) + # The requested state is already cached, do NOT fetch it again. + continue + except ValueError: + # The requested state is missing, fetch from redis. + pass parent_state = await state_manager.get_state( token=_substate_key( self.router.session.client_token, parent_state_name diff --git a/tests/test_state.py b/tests/test_state.py index a770b5ed4..23fa1fa75 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2729,6 +2729,99 @@ async def test_get_state(mock_app: rx.App, token: str): } +@pytest.mark.asyncio +async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): + """A test simulating update_vars_internal when setting cookies with computed vars. + + In that case, a sibling state, UpdateVarsInternalState handles the fetching + of states that need to have values set. Only the states that have a computed + var are pre-fetched (like Child3 in this test), so `get_state` needs to + avoid refetching those already-cached states when getting substates, + otherwise the set values will be overridden by the freshly deserialized + version and lost. + + Explicit regression test for https://github.com/reflex-dev/reflex/issues/2851. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class Parent(BaseState): + """A root state like rx.State.""" + + parent_var: int = 0 + + class Child(Parent): + """A state simulating UpdateVarsInternalState.""" + + pass + + class Child2(Parent): + """An unconnected child state.""" + + pass + + class Child3(Parent): + """A child state with a computed var causing it to be pre-fetched. + + If child3_var gets set to a value, and `get_state` erroneously + re-fetches it from redis, the value will be lost. + """ + + child3_var: int = 0 + + @rx.var + def v(self): + pass + + class Grandchild3(Child3): + """An extra layer of substate to catch an issue discovered in + _determine_missing_parent_states while writing the regression test where + invalid parent state names were being constructed. + """ + + pass + + class GreatGrandchild3(Grandchild3): + """Fetching this state wants to also fetch Child3 as a missing parent. + However, Child3 should already be cached in the state tree because it + has a computed var. + """ + + pass + + mock_app.state_manager.state = mock_app.state = Parent + + # Get the top level state via unconnected sibling. + root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + # Set value in parent_var to assert it does not get refetched later. + root.parent_var = 1 + + if isinstance(mock_app.state_manager, StateManagerRedis): + # When redis is used, only states with computed vars are pre-fetched. + assert "child2" not in root.substates + assert "child3" in root.substates # (due to @rx.var) + + # Get the unconnected sibling state, which will be used to `get_state` other instances. + child = root.get_substate(Child.get_full_name().split(".")) + + # Get an uncached child state. + child2 = await child.get_state(Child2) + assert child2.parent_var == 1 + + # Set value on already-cached Child3 state (prefetched because it has a Computed Var). + child3 = await child.get_state(Child3) + child3.child3_var = 1 + + # Get uncached great_grandchild3 state. + great_grandchild3 = await child.get_state(GreatGrandchild3) + + # Assert that we didn't re-fetch the parent and child3 state from redis + assert great_grandchild3.parent_var == 1 + assert great_grandchild3.child3_var == 1 + + # Save a reference to the rx.State to shadow the name State for testing. RxState = State