[REF-1988] API to Get instance of Arbitrary State class ()

* WiP get_state

* Refactor get_state fast path

Rudimentary protection for state instance access from a background task
(StateProxy)

* retain dirty substate marking per `_mark_dirty` call to avoid test changes

* Find common ancestor by part instead of by character

Fix StateProxy for substates and parent_state attributes (have to handle in
__getattr__, not property)

Fix type annotation for `get_state`

* test_state: workflow test for `get_state` functionality

* Do not reset _always_dirty_substates when adding vars

Reset the substate tracking only when the class is instantiated.

* test_state_tree: test substate access in a larger state tree

Ensure that `get_state` returns the proper "branch" of the state tree depending
on what substate is requested.

* test_format: fixup broken tests from adding substates of TestState

* Fix flaky integration tests with more polling

* AppHarness: reset _always_dirty_substates on rx.State

* RuntimeError unless State is instantiated with _reflex_internal_init=True

Avoid user errors trying to directly instantiate State classes

* Helper functions for _substate_key and _split_substate_key

Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.

* StateManagerRedis: use create_task in get_state and set_state

read and write substates concurrently (allow redis to shine)

* test_state_inheritance: use polling cuz life too short for flaky tests

kthnxbai ❤️

* Move _is_testing_env to reflex.utils.exec.is_testing_env

Reuse the code in app.py

* Break up `BaseState.get_state` and friends into separate methods

* Add test case for pre-fetching cached var dependency

* Move on_load_internal and update_vars_internal to substates

Avoid loading the entire state tree to process these common internal events. If
the state tree is very large, this allow page navigation to occur more quickly.

Pre-fetch substates that contain cached vars, as they may need to be recomputed
if certain vars change.

* Do not copy ROUTER_DATA into all substates.

This is a waste of time and memory, and can be handled via a special case in
__getattribute__

* Track whether State instance _was_touched

Avoid wasting time serializing states that have no modifications

* Do not persist states in `StateManagerRedis.get_state`

Wait until the state is actually modified, and then persist it as part of `set_state`.

Factor out common logic into helper methods for readability and to reduce
duplication of common logic.

To avoid having to recursively call `get_state`, which would require persisting
the instance and then getting it again, some of the initialization logic
regarding parent_state and substates is duplicated when creating a new
instance. This is for performance reasons.

* Remove stray print()

* context.js.jinja2: fix check for empty local storage / cookie vars

* Add comments for onLoadInternalEvent and initialEvents

* nit: typo

* split _get_was_touched into _update_was_touched

Improve clarity in cases where _get_was_touched was being called for its side
effects only.

* Remove extraneous information from incorrect State instantiation error

* Update missing redis exception message
This commit is contained in:
Masen Furer 2024-02-27 13:02:08 -08:00 committed by GitHub
parent bf07315cb4
commit deae662e2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1211 additions and 156 deletions

View File

@ -518,8 +518,8 @@ async def test_client_side_state(
set_sub("l6", "l6 value")
l5 = driver.find_element(By.ID, "l5")
l6 = driver.find_element(By.ID, "l6")
assert AppHarness._poll_for(lambda: l6.text == "l6 value")
assert l5.text == "l5 value"
assert l6.text == "l6 value"
# Switch back to main window.
driver.switch_to.window(main_tab)
@ -527,8 +527,8 @@ async def test_client_side_state(
# The values should have updated automatically.
l5 = driver.find_element(By.ID, "l5")
l6 = driver.find_element(By.ID, "l6")
assert AppHarness._poll_for(lambda: l6.text == "l6 value")
assert l5.text == "l5 value"
assert l6.text == "l6 value"
# clear the cookie jar and local storage, ensure state reset to default
driver.delete_all_cookies()

View File

@ -1,14 +1,29 @@
"""Test state inheritance."""
import time
from contextlib import suppress
from typing import Generator
import pytest
from selenium.common.exceptions import NoAlertPresentException
from selenium.webdriver.common.alert import Alert
from selenium.webdriver.common.by import By
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def get_alert_or_none(driver: WebDriver) -> Alert | None:
"""Switch to an alert if present.
Args:
driver: WebDriver instance.
Returns:
The alert if present, otherwise None.
"""
with suppress(NoAlertPresentException):
return driver.switch_to.alert
def raises_alert(driver: WebDriver, element: str) -> None:
"""Click an element and check that an alert is raised.
@ -18,8 +33,8 @@ def raises_alert(driver: WebDriver, element: str) -> None:
"""
btn = driver.find_element(By.ID, element)
btn.click()
time.sleep(0.2) # wait for the alert to appear
alert = driver.switch_to.alert
alert = AppHarness._poll_for(lambda: get_alert_or_none(driver))
assert isinstance(alert, Alert)
assert alert.text == "clicked"
alert.accept()
@ -355,7 +370,7 @@ def test_state_inheritance(
child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn")
child3_other_mixin_btn.click()
child2_other_mixin_value = state_inheritance.poll_for_content(
child2_other_mixin, exp_not_equal="other_mixin"
child2_other_mixin, exp_not_equal="Child2.clicked.1"
)
child2_computed_mixin_value = state_inheritance.poll_for_content(
child2_computed_other_mixin, exp_not_equal="other_mixin"

View File

@ -25,11 +25,31 @@ export const clientStorage = {}
{% if state_name %}
export const state_name = "{{state_name}}"
export const onLoadInternalEvent = () => [
Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}),
Event('{{state_name}}.{{const.on_load_internal}}')
]
// Theses events are triggered on initial load and each page navigation.
export const onLoadInternalEvent = () => {
const internal_events = [];
// Get tracked cookie and local storage vars to send to the backend.
const client_storage_vars = hydrateClientStorage(clientStorage);
// But only send the vars if any are actually set in the browser.
if (client_storage_vars && Object.keys(client_storage_vars).length !== 0) {
internal_events.push(
Event(
'{{state_name}}.{{const.update_vars_internal}}',
{vars: client_storage_vars},
),
);
}
// `on_load_internal` triggers the correct on_load event(s) for the current page.
// If the page does not define any on_load event, this will just set `is_hydrated = true`.
internal_events.push(Event('{{state_name}}.{{const.on_load_internal}}'));
return internal_events;
}
// The following events are sent when the websocket connects or reconnects.
export const initialEvents = () => [
Event('{{state_name}}.{{const.hydrate}}'),
...onLoadInternalEvent()

View File

@ -587,7 +587,7 @@ export const useEventLoop = (
if (storage_to_state_map[e.key]) {
const vars = {}
vars[storage_to_state_map[e.key]] = e.newValue
const event = Event(`${state_name}.update_vars_internal`, {vars: vars})
const event = Event(`${state_name}.update_vars_internal_state.update_vars_internal`, {vars: vars})
addEvents([event], e);
}
};

View File

@ -69,9 +69,11 @@ from reflex.state import (
State,
StateManager,
StateUpdate,
_substate_key,
code_uses_state_contexts,
)
from reflex.utils import console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_testing_env
from reflex.utils.imports import ImportVar
# Define custom types.
@ -159,10 +161,9 @@ class App(Base):
)
super().__init__(*args, **kwargs)
state_subclasses = BaseState.__subclasses__()
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
# Special case to allow test cases have multiple subclasses of rx.BaseState.
if not is_testing_env:
if not is_testing_env():
# Only one Base State class is allowed.
if len(state_subclasses) > 1:
raise ValueError(
@ -176,7 +177,8 @@ class App(Base):
deprecation_version="0.3.5",
removal_version="0.5.0",
)
if len(State.class_subclasses) > 0:
# 2 substates are built-in and not considered when determining if app is stateless.
if len(State.class_subclasses) > 2:
self.state = State
# Get the config
config = get_config()
@ -1002,7 +1004,7 @@ def upload(app: App):
)
# Get the state for the session.
substate_token = token + "_" + handler.rpartition(".")[0]
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)
# get the current session ID

View File

@ -138,12 +138,12 @@ def compile_state(state: Type[BaseState]) -> dict:
A dictionary of the compiled state.
"""
try:
initial_state = state().dict(initial=True)
initial_state = state(_reflex_internal_init=True).dict(initial=True)
except Exception as e:
console.warn(
f"Failed to compile initial state with computed vars, excluding them: {e}"
)
initial_state = state().dict(include_computed=False)
initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
return format.format_state(initial_state)

View File

@ -59,9 +59,9 @@ class CompileVars(SimpleNamespace):
# The name of the function for converting a dict to an event.
TO_EVENT = "Event"
# The name of the internal on_load event.
ON_LOAD_INTERNAL = "on_load_internal"
ON_LOAD_INTERNAL = "on_load_internal_state.on_load_internal"
# The name of the internal event to update generic state vars.
UPDATE_VARS_INTERNAL = "update_vars_internal"
UPDATE_VARS_INTERNAL = "update_vars_internal_state.update_vars_internal"
class PageNames(SimpleNamespace):

View File

@ -8,7 +8,6 @@ import copy
import functools
import inspect
import json
import os
import traceback
import urllib.parse
import uuid
@ -45,6 +44,7 @@ from reflex.event import (
)
from reflex.utils import console, format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = {
"_substate_var_dependencies",
"_always_dirty_computed_vars",
"_always_dirty_substates",
"_was_touched",
}
def _substate_key(
token: str,
state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
) -> str:
"""Get the substate key.
Args:
token: The token of the state.
state_cls_or_name: The state class/instance or name or sequence of name parts.
Returns:
The substate key.
"""
if isinstance(state_cls_or_name, BaseState) or (
isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
):
state_cls_or_name = state_cls_or_name.get_full_name()
elif isinstance(state_cls_or_name, (list, tuple)):
state_cls_or_name = ".".join(state_cls_or_name)
return f"{token}_{state_cls_or_name}"
def _split_substate_key(substate_key: str) -> tuple[str, str]:
"""Split the substate key into token and state name.
Args:
substate_key: The substate key.
Returns:
Tuple of token and state name.
"""
token, _, state_name = substate_key.partition("_")
return token, state_name
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
@ -214,29 +250,46 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The router data for the current page
router: RouterData = RouterData()
# Whether the state has ever been touched since instantiation.
_was_touched: bool = False
def __init__(
self,
*args,
parent_state: BaseState | None = None,
init_substates: bool = True,
_reflex_internal_init: bool = False,
**kwargs,
):
"""Initialize the state.
DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
Args:
*args: The args to pass to the Pydantic init method.
parent_state: The parent state.
init_substates: Whether to initialize the substates in this instance.
_reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
**kwargs: The kwargs to pass to the Pydantic init method.
Raises:
RuntimeError: If the state is instantiated directly by end user.
"""
if not _reflex_internal_init and not is_testing_env():
raise RuntimeError(
"State classes should not be instantiated directly in a Reflex app. "
"See https://reflex.dev/docs/state for further information."
)
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)
# Setup the substates (for memory state manager only).
if init_substates:
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
self.substates[substate.get_name()] = substate(
parent_state=self,
_reflex_internal_init=True,
)
# Convert the event handlers to functions.
self._init_event_handlers()
@ -287,7 +340,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Raises:
ValueError: If a substate class shadows another.
"""
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
super().__init_subclass__(**kwargs)
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
@ -295,6 +347,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Reset subclass tracking for this class.
cls.class_subclasses = set()
# Reset dirty substate tracking for this class.
cls._always_dirty_substates = set()
# Get the parent vars.
parent_state = cls.get_parent_state()
if parent_state is not None:
@ -303,7 +358,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Check if another substate class with the same name has already been defined.
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
if is_testing_env:
if is_testing_env():
# Clear existing subclass with same name when app is reloaded via
# utils.prerequisites.get_app(reload=True)
parent_state.class_subclasses = set(
@ -325,6 +380,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
name: value
for name, value in cls.__dict__.items()
if types.is_backend_variable(name, cls)
and name not in RESERVED_BACKEND_VAR_NAMES
and name not in cls.inherited_backend_vars
and not isinstance(value, FunctionType)
}
@ -484,7 +540,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
cls._always_dirty_substates = set()
if cls._always_dirty_computed_vars:
# Tell parent classes that this substate has always dirty computed vars
state_name = cls.get_name()
@ -923,8 +978,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
**super().__getattribute__("inherited_vars"),
**super().__getattribute__("inherited_backend_vars"),
}
if name in inherited_vars:
return getattr(super().__getattribute__("parent_state"), name)
# For now, handle router_data updates as a special case.
if name in inherited_vars or name == constants.ROUTER_DATA:
parent_state = super().__getattribute__("parent_state")
if parent_state is not None:
return getattr(parent_state, name)
backend_vars = super().__getattribute__("_backend_vars")
if name in backend_vars:
@ -980,9 +1039,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if name == constants.ROUTER_DATA:
self.dirty_vars.add(name)
self._mark_dirty()
# propagate router_data updates down the state tree
for substate in self.substates.values():
setattr(substate, name, value)
def reset(self):
"""Reset all the base vars to their default values."""
@ -1036,6 +1092,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Invalid path: {path}")
return self.substates[path[0]].get_substate(path[1:])
@classmethod
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
"""Find the name of the nearest common ancestor shared by this and the other state.
Args:
other: The other state.
Returns:
Full name of the nearest common ancestor.
"""
common_ancestor_parts = []
for part1, part2 in zip(
cls.get_full_name().split("."),
other.get_full_name().split("."),
):
if part1 != part2:
break
common_ancestor_parts.append(part1)
return ".".join(common_ancestor_parts)
@classmethod
def _determine_missing_parent_states(
cls, target_state_cls: Type[BaseState]
) -> tuple[str, list[str]]:
"""Determine the missing parent states between the target_state_cls and common ancestor of this state.
Args:
target_state_cls: The class of the state to find missing parent states for.
Returns:
The name of the common ancestor and the list of missing parent states.
"""
common_ancestor_name = cls._get_common_ancestor(target_state_cls)
common_ancestor_parts = common_ancestor_name.split(".")
target_state_parts = tuple(target_state_cls.get_full_name().split("."))
relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
# Determine which parent states to fetch from the common ancestor down to the target_state_cls.
fetch_parent_states = [common_ancestor_name]
for ix, relative_parent_state_name in enumerate(relative_target_state_parts):
fetch_parent_states.append(
".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name])
)
return common_ancestor_name, fetch_parent_states[1:-1]
def _get_parent_states(self) -> list[tuple[str, BaseState]]:
"""Get all parent state instances up to the root of the state tree.
Returns:
A list of tuples containing the name and the instance of each parent state.
"""
parent_states_with_name = []
parent_state = self
while parent_state.parent_state is not None:
parent_state = parent_state.parent_state
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
return parent_states_with_name
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
Args:
target_state_cls: The class of the state to populate parent states for.
Returns:
The parent state instance of target_state_cls.
Raises:
RuntimeError: If redis is not used in this backend process.
"""
state_manager = get_state_manager()
if not isinstance(state_manager, StateManagerRedis):
raise RuntimeError(
f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
"(All states should already be available -- this is likely a bug).",
)
# Find the missing parent states up to the common ancestor.
(
common_ancestor_name,
missing_parent_states,
) = self._determine_missing_parent_states(target_state_cls)
# Fetch all missing parent states and link them up to the common ancestor.
parent_states_by_name = dict(self._get_parent_states())
parent_state = parent_states_by_name[common_ancestor_name]
for parent_state_name in missing_parent_states:
parent_state = await state_manager.get_state(
token=_substate_key(
self.router.session.client_token, parent_state_name
),
top_level=False,
get_substates=False,
parent_state=parent_state,
)
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state
def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
"""Get a state instance from the cache.
Args:
state_cls: The class of the state.
Returns:
The instance of state_cls associated with this state's client_token.
"""
if self.parent_state is None:
root_state = self
else:
root_state = self._get_parent_states()[-1][1]
return root_state.get_substate(state_cls.get_full_name().split("."))
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
"""Get a state instance from redis.
Args:
state_cls: The class of the state.
Returns:
The instance of state_cls associated with this state's client_token.
Raises:
RuntimeError: If redis is not used in this backend process.
"""
# Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
# Then get the target state and all its substates.
state_manager = get_state_manager()
if not isinstance(state_manager, StateManagerRedis):
raise RuntimeError(
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
return await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
"""Get an instance of the state associated with this token.
Allows for arbitrary access to sibling states from within an event handler.
Args:
state_cls: The class of the state.
Returns:
The instance of state_cls associated with this state's client_token.
"""
# Fast case - if this state instance is already cached, get_substate from root state.
try:
return self._get_state_from_cache(state_cls)
except ValueError:
pass
# Slow case - fetch missing parent states from redis.
return await self._get_state_from_redis(state_cls)
def _get_event_handler(
self, event: Event
) -> tuple[BaseState | StateProxy, EventHandler]:
@ -1238,6 +1458,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for cvar in self._computed_var_dependencies[dirty_var]
)
@classmethod
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
"""Determine substates which could be affected by dirty vars in this state.
Returns:
Set of State classes that may need to be fetched to recalc computed vars.
"""
# _always_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = set(
cls.get_class_substate(tuple(substate_name.split(".")))
for substate_name in cls._always_dirty_substates
)
# Substates with cached vars also need to be fetched.
for dependent_substates in cls._substate_var_dependencies.values():
fetch_substates.update(
set(
cls.get_class_substate(tuple(substate_name.split(".")))
for substate_name in dependent_substates
)
)
return fetch_substates
def get_delta(self) -> Delta:
"""Get the delta for the state.
@ -1269,8 +1511,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates.union(self._always_dirty_substates):
if substate not in substates:
continue # substate not loaded at this time, no delta
delta.update(substates[substate].get_delta())
# Format the delta.
@ -1292,20 +1532,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
self._mark_dirty_substates()
# Propagate dirty var / computed var status into substates
def _mark_dirty_substates(self):
"""Propagate dirty var / computed var status into substates."""
substates = self.substates
for var in self.dirty_vars:
for substate_name in self._substate_var_dependencies[var]:
self.dirty_substates.add(substate_name)
if substate_name not in substates:
continue
substate = substates[substate_name]
substate.dirty_vars.add(var)
substate._mark_dirty()
def _update_was_touched(self):
"""Update the _was_touched flag based on dirty_vars."""
if self.dirty_vars and not self._was_touched:
for var in self.dirty_vars:
if var in self.base_vars or var in self._backend_vars:
self._was_touched = True
break
def _get_was_touched(self) -> bool:
"""Check current dirty_vars and flag to determine if state instance was modified.
If any dirty vars belong to this state, mark _was_touched.
This flag determines whether this state instance should be persisted to redis.
Returns:
Whether this state instance was ever modified.
"""
# Ensure the flag is up to date based on the current dirty_vars
self._update_was_touched()
return self._was_touched
def _clean(self):
"""Reset the dirty vars."""
# Update touched status before cleaning dirty_vars.
self._update_was_touched()
# Recursively clean the substates.
for substate in self.dirty_substates:
if substate not in self.substates:
@ -1422,6 +1687,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
state["__dict__"] = state["__dict__"].copy()
state["__dict__"]["parent_state"] = None
state["__dict__"]["substates"] = {}
state["__dict__"].pop("_was_touched", None)
return state
@ -1431,28 +1697,11 @@ class State(BaseState):
# The hydrated bool.
is_hydrated: bool = False
def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page.
Returns:
The list of events to queue for on load handling.
"""
# Do not app.compile_()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
if not load_events and self.is_hydrated:
return # Fast path for page-to-page navigation
self.is_hydrated = False
return [
*fix_events(
load_events,
self.router.session.client_token,
router_data=self.router_data,
),
type(self).set_is_hydrated(True), # type: ignore
]
class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates."""
def update_vars_internal(self, vars: dict[str, Any]) -> None:
async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""Apply updates to fully qualified state vars.
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
@ -1466,10 +1715,42 @@ class State(BaseState):
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state = self.get_substate(state_name.split("."))
var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)
class OnLoadInternalState(State):
"""Substate for handling on_load event enumeration.
This is a separate substate to avoid deserializing the entire state tree for every page navigation.
"""
def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page.
Returns:
The list of events to queue for on load handling.
"""
# Do not app.compile_()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
if not load_events and self.is_hydrated:
return # Fast path for page-to-page navigation
if not load_events:
self.is_hydrated = True
return # Fast path for initial hydrate with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
self.router.session.client_token,
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
]
class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task.
@ -1522,9 +1803,10 @@ class StateProxy(wrapt.ObjectProxy):
This StateProxy instance in mutable mode.
"""
self._self_actx = self._self_app.modify_state(
self.__wrapped__.router.session.client_token
+ "_"
+ ".".join(self._self_substate_path)
token=_substate_key(
self.__wrapped__.router.session.client_token,
self._self_substate_path,
)
)
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
@ -1574,7 +1856,15 @@ class StateProxy(wrapt.ObjectProxy):
Returns:
The value of the attribute.
Raises:
ImmutableStateError: If the state is not in mutable mode.
"""
if name in ["substates", "parent_state"] and not self._self_mutable:
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
value = super().__getattr__(name)
if not name.startswith("_self_") and isinstance(value, MutableProxy):
# ensure mutations to these containers are blocked unless proxy is _mutable
@ -1622,6 +1912,60 @@ class StateProxy(wrapt.ObjectProxy):
"manager. Use `async with self` to modify state."
)
def get_substate(self, path: Sequence[str]) -> BaseState:
"""Only allow substate access with lock held.
Args:
path: The path to the substate.
Returns:
The substate.
Raises:
ImmutableStateError: If the state is not in mutable mode.
"""
if not self._self_mutable:
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
return self.__wrapped__.get_substate(path)
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
"""Get an instance of the state associated with this token.
Args:
state_cls: The class of the state.
Returns:
The state.
Raises:
ImmutableStateError: If the state is not in mutable mode.
"""
if not self._self_mutable:
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
return await self.__wrapped__.get_state(state_cls)
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
"""Temporarily allow mutability to access parent_state.
Args:
*args: The args to pass to the underlying state instance.
**kwargs: The kwargs to pass to the underlying state instance.
Returns:
The state update.
"""
self._self_mutable = True
try:
return self.__wrapped__._as_state_update(*args, **kwargs)
finally:
self._self_mutable = False
class StateUpdate(Base):
"""A state update sent to the frontend."""
@ -1722,9 +2066,9 @@ class StateManagerMemory(StateManager):
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
token = _split_substate_key(token)[0]
if token not in self.states:
self.states[token] = self.state()
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]
async def set_state(self, token: str, state: BaseState):
@ -1747,7 +2091,7 @@ class StateManagerMemory(StateManager):
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
token = _split_substate_key(token)[0]
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:
@ -1787,6 +2131,81 @@ class StateManagerRedis(StateManager):
b"evicted",
}
def _get_root_state(self, state: BaseState) -> BaseState:
"""Chase parent_state pointers to find an instance of the top-level state.
Args:
state: The state to start from.
Returns:
An instance of the top-level state (self.state).
"""
while type(state) != self.state and state.parent_state is not None:
state = state.parent_state
return state
async def _get_parent_state(self, token: str) -> BaseState | None:
"""Get the parent state for the state requested in the token.
Args:
token: The token to get the state for (_substate_key).
Returns:
The parent state for the state requested by the token or None if there is no such parent.
"""
parent_state = None
client_token, state_path = _split_substate_key(token)
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
# Retrieve the parent state to populate event handlers onto this substate.
parent_state = await self.get_state(
token=_substate_key(client_token, parent_state_name),
top_level=False,
get_substates=False,
)
return parent_state
async def _populate_substates(
self,
token: str,
state: BaseState,
all_substates: bool = False,
):
"""Fetch and link substates for the given state instance.
There is no return value; the side-effect is that `state` will have `substates` populated,
and each substate will have its `parent_state` set to `state`.
Args:
token: The token to get the state for.
state: The state instance to populate substates for.
all_substates: Whether to fetch all substates or just required substates.
"""
client_token, _ = _split_substate_key(token)
if all_substates:
# All substates are requested.
fetch_substates = state.get_substates()
else:
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = state._potentially_dirty_substates()
tasks = {}
# Retrieve the necessary substates from redis.
for substate_cls in fetch_substates:
substate_name = substate_cls.get_name()
tasks[substate_name] = asyncio.create_task(
self.get_state(
token=_substate_key(client_token, substate_cls),
top_level=False,
get_substates=all_substates,
parent_state=state,
)
)
for substate_name, substate_task in tasks.items():
state.substates[substate_name] = await substate_task
async def get_state(
self,
token: str,
@ -1798,8 +2217,8 @@ class StateManagerRedis(StateManager):
Args:
token: The token to get the state for.
top_level: If true, return an instance of the top-level state.
get_substates: If true, also retrieve substates
top_level: If true, return an instance of the top-level state (self.state).
get_substates: If true, also retrieve substates.
parent_state: If provided, use this parent_state instead of getting it from redis.
Returns:
@ -1809,7 +2228,7 @@ class StateManagerRedis(StateManager):
RuntimeError: when the state_cls is not specified in the token
"""
# Split the actual token from the fully qualified substate name.
client_token, _, state_path = token.partition("_")
_, state_path = _split_substate_key(token)
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
@ -1825,66 +2244,49 @@ class StateManagerRedis(StateManager):
# Deserialize the substate.
state = cloudpickle.loads(redis_state)
# Populate parent and substates if requested.
# Populate parent state if missing and requested.
if parent_state is None:
# Retrieve the parent state from redis.
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
parent_state_key = token.rpartition(".")[0]
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
)
parent_state = await self._get_parent_state(token)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
if get_substates:
# Retrieve all substates from redis.
for substate_cls in state_cls.get_substates():
substate_name = substate_cls.get_name()
substate_key = token + "." + substate_name
state.substates[substate_name] = await self.get_state(
substate_key, top_level=False, parent_state=state
)
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
while type(state) != self.state and state.parent_state is not None:
state = state.parent_state
return self._get_root_state(state)
return state
# Key didn't exist so we have to create a new entry for this token.
# TODO: dedupe the following logic with the above block
# Key didn't exist so we have to create a new instance for this token.
if parent_state is None:
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
# Retrieve the parent state to populate event handlers onto this substate.
parent_state_key = client_token + "_" + parent_state_name
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
)
# Persist the new state class to redis.
await self.set_state(
token,
state_cls(
parent_state=parent_state,
init_substates=False,
),
)
# After creating the state key, recursively call `get_state` to populate substates.
return await self.get_state(
token,
top_level=top_level,
get_substates=get_substates,
parent_state = await self._get_parent_state(token)
# Instantiate the new state class (but don't persist it yet).
state = state_cls(
parent_state=parent_state,
init_substates=False,
_reflex_internal_init=True,
)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
# Populate substates for the newly created state.
await self._populate_substates(token, state, all_substates=get_substates)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
return self._get_root_state(state)
return state
async def set_state(
self,
token: str,
state: BaseState,
lock_id: bytes | None = None,
set_substates: bool = True,
set_parent_state: bool = True,
):
"""Set the state for a token.
@ -1892,11 +2294,10 @@ class StateManagerRedis(StateManager):
token: The token to set the state for.
state: The state to set.
lock_id: If provided, the lock_key must be set to this value to set the state.
set_substates: If True, write substates to redis
set_parent_state: If True, write parent state to redis
Raises:
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
RuntimeError: If the state instance doesn't match the state name in the token.
"""
# Check that we're holding the lock.
if (
@ -1908,28 +2309,36 @@ class StateManagerRedis(StateManager):
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.background` decorator for long-running tasks."
)
# Find the substate associated with the token.
state_path = token.partition("_")[2]
if state_path and state.get_full_name() != state_path:
state = state.get_substate(tuple(state_path.split(".")))
# Persist the parent state separately, if requested.
if state.parent_state is not None and set_parent_state:
parent_state_key = token.rpartition(".")[0]
await self.set_state(
parent_state_key,
state.parent_state,
lock_id=lock_id,
set_substates=False,
client_token, substate_name = _split_substate_key(token)
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
if state.parent_state is not None and state.get_full_name() != substate_name:
raise RuntimeError(
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
)
# Persist the substates separately, if requested.
if set_substates:
for substate_name, substate in state.substates.items():
substate_key = token + "." + substate_name
await self.set_state(
substate_key, substate, lock_id=lock_id, set_parent_state=False
# Recursively set_state on all known substates.
tasks = []
for substate in state.substates.values():
tasks.append(
asyncio.create_task(
self.set_state(
token=_substate_key(client_token, substate),
state=substate,
lock_id=lock_id,
)
)
)
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
if state._get_was_touched():
await self.redis.set(
_substate_key(client_token, state),
cloudpickle.dumps(state),
ex=self.token_expiration,
)
# Wait for substates to be persisted.
for t in tasks:
await t
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
@ -1957,7 +2366,7 @@ class StateManagerRedis(StateManager):
The redis lock key for the token.
"""
# All substates share the same lock domain, so ignore any substate path suffix.
client_token = token.partition("_")[0]
client_token = _split_substate_key(token)[0]
return f"{client_token}_lock".encode()
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
@ -2052,6 +2461,16 @@ class StateManagerRedis(StateManager):
await self.redis.close(close_connection_pool=True)
def get_state_manager() -> StateManager:
"""Get the state manager for the app that is currently running.
Returns:
The state manager.
"""
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
return app.state_manager
class ClientStorageBase:
"""Base class for client-side storage."""

View File

@ -70,6 +70,10 @@ else:
FRONTEND_POPEN_ARGS["start_new_session"] = True
# Save a copy of internal substates to reset after each test.
INTERNAL_STATES = State.class_subclasses.copy()
# borrowed from py3.11
class chdir(contextlib.AbstractContextManager):
"""Non thread-safe context manager to change the current working directory."""
@ -220,6 +224,8 @@ class AppHarness:
reflex.config.get_config(reload=True)
# reset rx.State subclasses
State.class_subclasses.clear()
State.class_subclasses.update(INTERNAL_STATES)
State._always_dirty_substates = set()
State.get_class_substate.cache_clear()
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)

View File

@ -285,3 +285,12 @@ def output_system_info():
console.debug(f"Using package executer at: {prerequisites.get_package_manager()}") # type: ignore
if system != "Windows":
console.debug(f"Unzip path: {path_ops.which('unzip')}")
def is_testing_env() -> bool:
"""Whether the app is running in a testing environment.
Returns:
True if the app is running in under pytest.
"""
return constants.PYTEST_CURRENT_TEST in os.environ

View File

@ -1875,6 +1875,10 @@ class ComputedVar(Var, property):
Returns:
A set of variable names accessed by the given obj.
Raises:
ValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
"""
d = set()
if obj is None:
@ -1898,6 +1902,8 @@ class ComputedVar(Var, property):
if self_name is None:
# cannot reference attributes on self if method takes no args
return set()
invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
self_is_top_of_stack = False
for instruction in dis.get_instructions(obj):
if (
@ -1916,6 +1922,10 @@ class ComputedVar(Var, property):
ref_obj = getattr(objclass, instruction.argval)
except Exception:
ref_obj = None
if instruction.argval in invalid_names:
raise ValueError(
f"Cached var {self._var_full_name} cannot access arbitrary state via `{instruction.argval}`."
)
if callable(ref_obj):
# recurse into callable attributes
d.update(

View File

@ -29,7 +29,15 @@ from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
from reflex.state import (
BaseState,
OnLoadInternalState,
RouterData,
State,
StateManagerRedis,
StateUpdate,
_substate_key,
)
from reflex.style import Style
from reflex.utils import format
from reflex.vars import ComputedVar
@ -362,7 +370,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
assert app.state == test_state
# Get a state for a given token.
state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
state = await app.state_manager.get_state(_substate_key(token, test_state))
assert isinstance(state, test_state)
assert state.var == 0 # type: ignore
@ -766,8 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
# The App state must be the "root" of the state tree
app = App(state=State)
app.event_namespace.emit = AsyncMock() # type: ignore
substate_token = f"{token}_{state.get_full_name()}"
current_state = await app.state_manager.get_state(substate_token)
current_state = await app.state_manager.get_state(_substate_key(token, state))
data = b"This is binary data"
# Create a binary IO object and write data to it
@ -796,7 +803,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
)
current_state = await app.state_manager.get_state(substate_token)
current_state = await app.state_manager.get_state(_substate_key(token, state))
state_dict = current_state.dict()[state.get_full_name()]
assert state_dict["img_list"] == [
"image1.jpg",
@ -913,7 +920,7 @@ class DynamicState(BaseState):
# self.side_effect_counter = self.side_effect_counter + 1
return self.dynamic
on_load_internal = State.on_load_internal.fn
on_load_internal = OnLoadInternalState.on_load_internal.fn
@pytest.mark.asyncio
@ -950,7 +957,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
}
assert constants.ROUTER in app.state()._computed_var_dependencies
substate_token = f"{token}_{DynamicState.get_full_name()}"
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
client_ip = "127.0.0.1"
state = await app.state_manager.get_state(substate_token)
@ -978,7 +985,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
prev_exp_val = ""
for exp_index, exp_val in enumerate(exp_vals):
on_load_internal = _event(
name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}",
val=exp_val,
)
exp_router_data = {
@ -1013,8 +1020,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
name="on_load",
val=exp_val,
),
_dynamic_state_event(
name="set_is_hydrated",
_event(
name="state.set_is_hydrated",
payload={"value": True},
val=exp_val,
router_data={},

View File

@ -23,6 +23,7 @@ from reflex.state import (
ImmutableStateError,
LockExpiredError,
MutableProxy,
OnLoadInternalState,
RouterData,
State,
StateManager,
@ -30,6 +31,7 @@ from reflex.state import (
StateManagerRedis,
StateProxy,
StateUpdate,
_substate_key,
)
from reflex.utils import prerequisites, types
from reflex.utils.format import json_dumps
@ -139,6 +141,12 @@ class ChildState2(TestState):
value: str
class ChildState3(TestState):
"""A child state fixture."""
value: str
class GrandchildState(ChildState):
"""A grandchild state fixture."""
@ -149,6 +157,32 @@ class GrandchildState(ChildState):
pass
class GrandchildState2(ChildState2):
"""A grandchild state fixture."""
@rx.cached_var
def cached(self) -> str:
"""A cached var.
Returns:
The value.
"""
return self.value
class GrandchildState3(ChildState3):
"""A great grandchild state fixture."""
@rx.var
def computed(self) -> str:
"""A computed var.
Returns:
The value.
"""
return self.value
class DateTimeState(BaseState):
"""A State with some datetime fields."""
@ -329,6 +363,9 @@ def test_dict(test_state):
"test_state.child_state",
"test_state.child_state.grandchild_state",
"test_state.child_state2",
"test_state.child_state2.grandchild_state2",
"test_state.child_state3",
"test_state.child_state3.grandchild_state3",
}
test_state_dict = test_state.dict()
assert set(test_state_dict) == substates
@ -380,10 +417,11 @@ def test_get_parent_state():
def test_get_substates():
"""Test getting the substates."""
assert TestState.get_substates() == {ChildState, ChildState2}
assert TestState.get_substates() == {ChildState, ChildState2, ChildState3}
assert ChildState.get_substates() == {GrandchildState}
assert ChildState2.get_substates() == set()
assert ChildState2.get_substates() == {GrandchildState2}
assert GrandchildState.get_substates() == set()
assert GrandchildState2.get_substates() == set()
def test_get_name():
@ -469,8 +507,8 @@ def test_set_parent_and_substates(test_state, child_state, grandchild_state):
child_state: A child state.
grandchild_state: A grandchild state.
"""
assert len(test_state.substates) == 2
assert set(test_state.substates) == {"child_state", "child_state2"}
assert len(test_state.substates) == 3
assert set(test_state.substates) == {"child_state", "child_state2", "child_state3"}
assert child_state.parent_state == test_state
assert len(child_state.substates) == 1
@ -655,7 +693,7 @@ def test_reset(test_state, child_state):
assert child_state.dirty_vars == {"count", "value"}
# The dirty substates should be reset.
assert test_state.dirty_substates == {"child_state", "child_state2"}
assert test_state.dirty_substates == {"child_state", "child_state2", "child_state3"}
@pytest.mark.asyncio
@ -675,7 +713,10 @@ async def test_process_event_simple(test_state):
# The delta should contain the changes, including computed vars.
# assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
assert update.delta == {
"test_state": {"num1": 69, "sum": 72.14, "upper": ""},
"test_state.child_state3.grandchild_state3": {"computed": ""},
}
assert update.events == []
@ -700,6 +741,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state": {"value": "HI", "count": 24},
"test_state.child_state3.grandchild_state3": {"computed": ""},
}
test_state._clean()
@ -715,6 +757,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state.grandchild_state": {"value2": "new"},
"test_state.child_state3.grandchild_state3": {"computed": ""},
}
@ -1443,7 +1486,7 @@ def substate_token(state_manager, token):
Returns:
Token concatenated with the state_manager's state full_name.
"""
return f"{token}_{state_manager.state.get_full_name()}"
return _substate_key(token, state_manager.state)
@pytest.mark.asyncio
@ -1545,7 +1588,7 @@ def substate_token_redis(state_manager_redis, token):
Returns:
Token concatenated with the state_manager's state full_name.
"""
return f"{token}_{state_manager_redis.state.get_full_name()}"
return _substate_key(token, state_manager_redis.state)
@pytest.mark.asyncio
@ -1670,6 +1713,22 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# cannot directly modify state proxy outside of async context
sp.value2 = "16"
with pytest.raises(ImmutableStateError):
# Cannot get_state
await sp.get_state(ChildState)
with pytest.raises(ImmutableStateError):
# Cannot access get_substate
sp.get_substate([])
with pytest.raises(ImmutableStateError):
# Cannot access parent state
sp.parent_state.get_name()
with pytest.raises(ImmutableStateError):
# Cannot access substates
sp.substates[""]
async with sp:
assert sp._self_actx is not None
assert sp._self_mutable # proxy is mutable inside context
@ -1685,8 +1744,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert sp.value2 == "42"
# Get the state from the state manager directly and check that the value is updated
gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
gotten_state = await mock_app.state_manager.get_state(gc_token)
gotten_state = await mock_app.state_manager.get_state(
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
)
if isinstance(mock_app.state_manager, StateManagerMemory):
# For in-process store, only one instance of the state exists
assert gotten_state is parent_state
@ -1710,6 +1770,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
)
assert mcall.kwargs["to"] == grandchild_state.get_sid()
@ -1879,8 +1942,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"private",
]
substate_token = f"{token}_{BackgroundTaskState.get_name()}"
assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
assert (
await mock_app.state_manager.get_state(
_substate_key(token, BackgroundTaskState)
)
).order == exp_order
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
@ -1957,8 +2023,11 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
await task
assert not mock_app.background_tasks
substate_token = f"{token}_{BackgroundTaskState.get_name()}"
assert (await mock_app.state_manager.get_state(substate_token)).order == [
assert (
await mock_app.state_manager.get_state(
_substate_key(token, BackgroundTaskState)
)
).order == [
"reset",
]
@ -2246,7 +2315,7 @@ def test_mutable_copy_vars(mutable_state, copy_func):
def test_duplicate_substate_class(mocker):
mocker.patch("reflex.state.os.environ", {})
mocker.patch("reflex.state.is_testing_env", lambda: False)
with pytest.raises(ValueError):
class TestState(BaseState):
@ -2435,7 +2504,9 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
expected: Expected delta.
mocker: pytest mock object.
"""
mocker.patch("reflex.state.State.class_subclasses", {test_state})
mocker.patch(
"reflex.state.State.class_subclasses", {test_state, OnLoadInternalState}
)
app = app_module_mock.app = App(
state=State, load_events={"index": [test_state.test_handler]}
)
@ -2476,7 +2547,9 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
token: A token.
mocker: pytest mock object.
"""
mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
mocker.patch(
"reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState}
)
app = app_module_mock.app = App(
state=State,
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
@ -2510,3 +2583,120 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
OnLoadState.get_full_name(): {"num": 2}
}
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
@pytest.mark.asyncio
async def test_get_state(mock_app: rx.App, token: str):
"""Test that a get_state populates the top level state and delta calculation is correct.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
mock_app.state_manager.state = mock_app.state = TestState
# Get instance of ChildState2.
test_state = await mock_app.state_manager.get_state(
_substate_key(token, ChildState2)
)
assert isinstance(test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory):
# All substates are available
assert tuple(sorted(test_state.substates)) == (
"child_state",
"child_state2",
"child_state3",
)
else:
# Sibling states are only populated if they have computed vars
assert tuple(sorted(test_state.substates)) == ("child_state2", "child_state3")
# Because ChildState3 has a computed var, it is always dirty, and always populated.
assert (
test_state.substates["child_state3"].substates["grandchild_state3"].computed
== ""
)
# Get the child_state2 directly.
child_state2_direct = test_state.get_substate(["child_state2"])
child_state2_get_state = await test_state.get_state(ChildState2)
# These should be the same object.
assert child_state2_direct is child_state2_get_state
# Get arbitrary GrandchildState.
grandchild_state = await child_state2_get_state.get_state(GrandchildState)
assert isinstance(grandchild_state, GrandchildState)
# Now the original root should have all substates populated.
assert tuple(sorted(test_state.substates)) == (
"child_state",
"child_state2",
"child_state3",
)
# ChildState should be retrievable
child_state_direct = test_state.get_substate(["child_state"])
child_state_get_state = await test_state.get_state(ChildState)
# These should be the same object.
assert child_state_direct is child_state_get_state
# GrandchildState instance should be the same as the one retrieved from the child_state2.
assert grandchild_state is child_state_direct.get_substate(["grandchild_state"])
grandchild_state.value2 = "set_value"
assert test_state.get_delta() == {
TestState.get_full_name(): {
"sum": 3.14,
"upper": "",
},
GrandchildState.get_full_name(): {
"value2": "set_value",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
# Get a fresh instance
new_test_state = await mock_app.state_manager.get_state(
_substate_key(token, ChildState2)
)
assert isinstance(new_test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory):
# In memory, it's the same instance
assert new_test_state is test_state
test_state._clean()
# All substates are available
assert tuple(sorted(new_test_state.substates)) == (
"child_state",
"child_state2",
"child_state3",
)
else:
# With redis, we get a whole new instance
assert new_test_state is not test_state
# Sibling states are only populated if they have computed vars
assert tuple(sorted(new_test_state.substates)) == (
"child_state2",
"child_state3",
)
# Set a value on child_state2, should update cached var in grandchild_state2
child_state2 = new_test_state.get_substate(("child_state2",))
child_state2.value = "set_c2_value"
assert new_test_state.get_delta() == {
TestState.get_full_name(): {
"sum": 3.14,
"upper": "",
},
ChildState2.get_full_name(): {
"value": "set_c2_value",
},
GrandchildState2.get_full_name(): {
"cached": "set_c2_value",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}

371
tests/test_state_tree.py Normal file
View File

@ -0,0 +1,371 @@
"""Specialized test for a larger state tree."""
import asyncio
from typing import Generator
import pytest
import reflex as rx
from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
class Root(BaseState):
"""Root of the state tree."""
root: int
class TreeA(Root):
"""TreeA is a child of Root."""
a: int
class SubA_A(TreeA):
"""SubA_A is a child of TreeA."""
sub_a_a: int
class SubA_A_A(SubA_A):
"""SubA_A_A is a child of SubA_A."""
sub_a_a_a: int
class SubA_A_A_A(SubA_A_A):
"""SubA_A_A_A is a child of SubA_A_A."""
sub_a_a_a_a: int
class SubA_A_A_B(SubA_A_A):
"""SubA_A_A_B is a child of SubA_A_A."""
@rx.cached_var
def sub_a_a_a_cached(self) -> int:
"""A cached var.
Returns:
The value of sub_a_a_a + 1
"""
return self.sub_a_a_a + 1
class SubA_A_A_C(SubA_A_A):
"""SubA_A_A_C is a child of SubA_A_A."""
sub_a_a_a_c: int
class SubA_A_B(SubA_A):
"""SubA_A_B is a child of SubA_A."""
sub_a_a_b: int
class SubA_B(TreeA):
"""SubA_B is a child of TreeA."""
sub_a_b: int
class TreeB(Root):
"""TreeB is a child of Root."""
b: int
class SubB_A(TreeB):
"""SubB_A is a child of TreeB."""
sub_b_a: int
class SubB_B(TreeB):
"""SubB_B is a child of TreeB."""
sub_b_b: int
class SubB_C(TreeB):
"""SubB_C is a child of TreeB."""
sub_b_c: int
class SubB_C_A(SubB_C):
"""SubB_C_A is a child of SubB_C."""
sub_b_c_a: int
class TreeC(Root):
"""TreeC is a child of Root."""
c: int
class SubC_A(TreeC):
"""SubC_A is a child of TreeC."""
sub_c_a: int
class TreeD(Root):
"""TreeD is a child of Root."""
d: int
@rx.var
def d_var(self) -> int:
"""A computed var.
Returns:
The value of d + 1
"""
return self.d + 1
class TreeE(Root):
"""TreeE is a child of Root."""
e: int
class SubE_A(TreeE):
"""SubE_A is a child of TreeE."""
sub_e_a: int
class SubE_A_A(SubE_A):
"""SubE_A_A is a child of SubE_A."""
sub_e_a_a: int
class SubE_A_A_A(SubE_A_A):
"""SubE_A_A_A is a child of SubE_A_A."""
sub_e_a_a_a: int
class SubE_A_A_A_A(SubE_A_A_A):
"""SubE_A_A_A_A is a child of SubE_A_A_A."""
sub_e_a_a_a_a: int
@rx.var
def sub_e_a_a_a_a_var(self) -> int:
"""A computed var.
Returns:
The value of sub_e_a_a_a_a + 1
"""
return self.sub_e_a_a_a + 1
class SubE_A_A_A_B(SubE_A_A_A):
"""SubE_A_A_A_B is a child of SubE_A_A_A."""
sub_e_a_a_a_b: int
class SubE_A_A_A_C(SubE_A_A_A):
"""SubE_A_A_A_C is a child of SubE_A_A_A."""
sub_e_a_a_a_c: int
class SubE_A_A_A_D(SubE_A_A_A):
"""SubE_A_A_A_D is a child of SubE_A_A_A."""
sub_e_a_a_a_d: int
@rx.cached_var
def sub_e_a_a_a_d_var(self) -> int:
"""A computed var.
Returns:
The value of sub_e_a_a_a_a + 1
"""
return self.sub_e_a_a_a + 1
ALWAYS_COMPUTED_VARS = {
TreeD.get_full_name(): {"d_var": 1},
SubE_A_A_A_A.get_full_name(): {"sub_e_a_a_a_a_var": 1},
}
ALWAYS_COMPUTED_DICT_KEYS = [
Root.get_full_name(),
TreeD.get_full_name(),
TreeE.get_full_name(),
SubE_A.get_full_name(),
SubE_A_A.get_full_name(),
SubE_A_A_A.get_full_name(),
SubE_A_A_A_A.get_full_name(),
SubE_A_A_A_D.get_full_name(),
]
@pytest.fixture(scope="function")
def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]:
"""Instance of state manager for redis only.
Args:
app_module_mock: The app module mock fixture.
Yields:
A state manager instance
"""
app_module_mock.app = rx.App(state=Root)
state_manager = app_module_mock.app.state_manager
if not isinstance(state_manager, StateManagerRedis):
pytest.skip("Test requires redis")
yield state_manager
asyncio.get_event_loop().run_until_complete(state_manager.close())
@pytest.mark.asyncio
@pytest.mark.parametrize(
("substate_cls", "exp_root_substates", "exp_root_dict_keys"),
[
(
Root,
["tree_a", "tree_b", "tree_c", "tree_d", "tree_e"],
[
TreeA.get_full_name(),
SubA_A.get_full_name(),
SubA_A_A.get_full_name(),
SubA_A_A_A.get_full_name(),
SubA_A_A_B.get_full_name(),
SubA_A_A_C.get_full_name(),
SubA_A_B.get_full_name(),
SubA_B.get_full_name(),
TreeB.get_full_name(),
SubB_A.get_full_name(),
SubB_B.get_full_name(),
SubB_C.get_full_name(),
SubB_C_A.get_full_name(),
TreeC.get_full_name(),
SubC_A.get_full_name(),
SubE_A_A_A_B.get_full_name(),
SubE_A_A_A_C.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
TreeA,
("tree_a", "tree_d", "tree_e"),
[
TreeA.get_full_name(),
SubA_A.get_full_name(),
SubA_A_A.get_full_name(),
SubA_A_A_A.get_full_name(),
SubA_A_A_B.get_full_name(),
SubA_A_A_C.get_full_name(),
SubA_A_B.get_full_name(),
SubA_B.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
SubA_A_A_A,
["tree_a", "tree_d", "tree_e"],
[
TreeA.get_full_name(),
SubA_A.get_full_name(),
SubA_A_A.get_full_name(),
SubA_A_A_A.get_full_name(),
SubA_A_A_B.get_full_name(), # Cached var dep
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
TreeB,
["tree_b", "tree_d", "tree_e"],
[
TreeB.get_full_name(),
SubB_A.get_full_name(),
SubB_B.get_full_name(),
SubB_C.get_full_name(),
SubB_C_A.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
SubB_B,
["tree_b", "tree_d", "tree_e"],
[
TreeB.get_full_name(),
SubB_B.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
SubB_C_A,
["tree_b", "tree_d", "tree_e"],
[
TreeB.get_full_name(),
SubB_C.get_full_name(),
SubB_C_A.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
TreeC,
["tree_c", "tree_d", "tree_e"],
[
TreeC.get_full_name(),
SubC_A.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
TreeD,
["tree_d", "tree_e"],
[
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
(
TreeE,
["tree_d", "tree_e"],
[
# Extra siblings of computed var included now.
SubE_A_A_A_B.get_full_name(),
SubE_A_A_A_C.get_full_name(),
*ALWAYS_COMPUTED_DICT_KEYS,
],
),
],
)
async def test_get_state_tree(
state_manager_redis,
token,
substate_cls,
exp_root_substates,
exp_root_dict_keys,
):
"""Test getting state trees and assert on which branches are retrieved.
Args:
state_manager_redis: The state manager redis fixture.
token: The token fixture.
substate_cls: The substate class to retrieve.
exp_root_substates: The expected substates of the root state.
exp_root_dict_keys: The expected keys of the root state dict.
"""
state = await state_manager_redis.get_state(_substate_key(token, substate_cls))
assert isinstance(state, Root)
assert sorted(state.substates) == sorted(exp_root_substates)
# Only computed vars should be returned
assert state.get_delta() == ALWAYS_COMPUTED_VARS
# All of TreeA, TreeD, and TreeE substates should be in the dict
assert sorted(state.dict()) == sorted(exp_root_dict_keys)

View File

@ -13,8 +13,11 @@ from reflex.vars import BaseVar, Var
from tests.test_state import (
ChildState,
ChildState2,
ChildState3,
DateTimeState,
GrandchildState,
GrandchildState2,
GrandchildState3,
TestState,
)
@ -649,7 +652,7 @@ formatted_router = {
"input, output",
[
(
TestState().dict(), # type: ignore
TestState(_reflex_internal_init=True).dict(), # type: ignore
{
TestState.get_full_name(): {
"array": [1, 2, 3.14],
@ -674,11 +677,14 @@ formatted_router = {
"value": "",
},
ChildState2.get_full_name(): {"value": ""},
ChildState3.get_full_name(): {"value": ""},
GrandchildState.get_full_name(): {"value2": ""},
GrandchildState2.get_full_name(): {"cached": ""},
GrandchildState3.get_full_name(): {"computed": ""},
},
),
(
DateTimeState().dict(),
DateTimeState(_reflex_internal_init=True).dict(), # type: ignore
{
DateTimeState.get_full_name(): {
"d": "1989-11-09",