[REF-2219] Avoid refetching states that are already cached (#2953)

* Add test_get_state_from_sibling_not_cached

A better unit test to catch issues with refetching parent states
and calculating the wrong parent state names to fetch.

* _determine_missing_parent_states: correctly generate state names

Prepend only the previous state name to the current relative_parent_state_name
instead of joining all of the previous state names together.

* [REF-2219] Avoid refetching states that are already cached

The already cached states may have unsaved changes which can be wiped out if
they are refetched from redis in the middle of handling an event.

If the root state already knows about one of the potentially missing states,
then use the instance that is already cached.

Fix #2851
This commit is contained in:
Masen Furer 2024-03-29 09:42:25 -07:00 committed by GitHub
parent 628c865530
commit 55b0fb36e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 3 deletions

View File

@ -1232,9 +1232,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Determine which parent states to fetch from the common ancestor down to the target_state_cls.
fetch_parent_states = [common_ancestor_name]
for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
for relative_parent_state_name in relative_target_state_parts:
fetch_parent_states.append(
".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
".".join((fetch_parent_states[-1], relative_parent_state_name))
)
return common_ancestor_name, fetch_parent_states[1:-1]
@ -1278,9 +1278,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
) = self._determine_missing_parent_states(target_state_cls)
# Fetch all missing parent states and link them up to the common ancestor.
parent_states_by_name = dict(self._get_parent_states())
parent_states_tuple = self._get_parent_states()
root_state = parent_states_tuple[-1][1]
parent_states_by_name = dict(parent_states_tuple)
parent_state = parent_states_by_name[common_ancestor_name]
for parent_state_name in missing_parent_states:
try:
parent_state = root_state.get_substate(parent_state_name.split("."))
# The requested state is already cached, do NOT fetch it again.
continue
except ValueError:
# The requested state is missing, fetch from redis.
pass
parent_state = await state_manager.get_state(
token=_substate_key(
self.router.session.client_token, parent_state_name

View File

@ -2729,6 +2729,99 @@ async def test_get_state(mock_app: rx.App, token: str):
}
@pytest.mark.asyncio
async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
"""A test simulating update_vars_internal when setting cookies with computed vars.
In that case, a sibling state, UpdateVarsInternalState handles the fetching
of states that need to have values set. Only the states that have a computed
var are pre-fetched (like Child3 in this test), so `get_state` needs to
avoid refetching those already-cached states when getting substates,
otherwise the set values will be overridden by the freshly deserialized
version and lost.
Explicit regression test for https://github.com/reflex-dev/reflex/issues/2851.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
class Parent(BaseState):
"""A root state like rx.State."""
parent_var: int = 0
class Child(Parent):
"""A state simulating UpdateVarsInternalState."""
pass
class Child2(Parent):
"""An unconnected child state."""
pass
class Child3(Parent):
"""A child state with a computed var causing it to be pre-fetched.
If child3_var gets set to a value, and `get_state` erroneously
re-fetches it from redis, the value will be lost.
"""
child3_var: int = 0
@rx.var
def v(self):
pass
class Grandchild3(Child3):
"""An extra layer of substate to catch an issue discovered in
_determine_missing_parent_states while writing the regression test where
invalid parent state names were being constructed.
"""
pass
class GreatGrandchild3(Grandchild3):
"""Fetching this state wants to also fetch Child3 as a missing parent.
However, Child3 should already be cached in the state tree because it
has a computed var.
"""
pass
mock_app.state_manager.state = mock_app.state = Parent
# Get the top level state via unconnected sibling.
root = await mock_app.state_manager.get_state(_substate_key(token, Child))
# Set value in parent_var to assert it does not get refetched later.
root.parent_var = 1
if isinstance(mock_app.state_manager, StateManagerRedis):
# When redis is used, only states with computed vars are pre-fetched.
assert "child2" not in root.substates
assert "child3" in root.substates # (due to @rx.var)
# Get the unconnected sibling state, which will be used to `get_state` other instances.
child = root.get_substate(Child.get_full_name().split("."))
# Get an uncached child state.
child2 = await child.get_state(Child2)
assert child2.parent_var == 1
# Set value on already-cached Child3 state (prefetched because it has a Computed Var).
child3 = await child.get_state(Child3)
child3.child3_var = 1
# Get uncached great_grandchild3 state.
great_grandchild3 = await child.get_state(GreatGrandchild3)
# Assert that we didn't re-fetch the parent and child3 state from redis
assert great_grandchild3.parent_var == 1
assert great_grandchild3.child3_var == 1
# Save a reference to the rx.State to shadow the name State for testing.
RxState = State