diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 3f7ff33f8..1be40d32b 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -518,8 +518,8 @@ async def test_client_side_state( set_sub("l6", "l6 value") l5 = driver.find_element(By.ID, "l5") l6 = driver.find_element(By.ID, "l6") + assert AppHarness._poll_for(lambda: l6.text == "l6 value") assert l5.text == "l5 value" - assert l6.text == "l6 value" # Switch back to main window. driver.switch_to.window(main_tab) @@ -527,8 +527,8 @@ async def test_client_side_state( # The values should have updated automatically. l5 = driver.find_element(By.ID, "l5") l6 = driver.find_element(By.ID, "l6") + assert AppHarness._poll_for(lambda: l6.text == "l6 value") assert l5.text == "l5 value" - assert l6.text == "l6 value" # clear the cookie jar and local storage, ensure state reset to default driver.delete_all_cookies() diff --git a/integration/test_state_inheritance.py b/integration/test_state_inheritance.py index addc0c654..24b26e523 100644 --- a/integration/test_state_inheritance.py +++ b/integration/test_state_inheritance.py @@ -1,14 +1,29 @@ """Test state inheritance.""" -import time +from contextlib import suppress from typing import Generator import pytest +from selenium.common.exceptions import NoAlertPresentException +from selenium.webdriver.common.alert import Alert from selenium.webdriver.common.by import By from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver +def get_alert_or_none(driver: WebDriver) -> Alert | None: + """Switch to an alert if present. + + Args: + driver: WebDriver instance. + + Returns: + The alert if present, otherwise None. + """ + with suppress(NoAlertPresentException): + return driver.switch_to.alert + + def raises_alert(driver: WebDriver, element: str) -> None: """Click an element and check that an alert is raised. @@ -18,8 +33,8 @@ def raises_alert(driver: WebDriver, element: str) -> None: """ btn = driver.find_element(By.ID, element) btn.click() - time.sleep(0.2) # wait for the alert to appear - alert = driver.switch_to.alert + alert = AppHarness._poll_for(lambda: get_alert_or_none(driver)) + assert isinstance(alert, Alert) assert alert.text == "clicked" alert.accept() @@ -355,7 +370,7 @@ def test_state_inheritance( child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn") child3_other_mixin_btn.click() child2_other_mixin_value = state_inheritance.poll_for_content( - child2_other_mixin, exp_not_equal="other_mixin" + child2_other_mixin, exp_not_equal="Child2.clicked.1" ) child2_computed_mixin_value = state_inheritance.poll_for_content( child2_computed_other_mixin, exp_not_equal="other_mixin" diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index 2c029fcd5..369b65136 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -25,11 +25,31 @@ export const clientStorage = {} {% if state_name %} export const state_name = "{{state_name}}" -export const onLoadInternalEvent = () => [ - Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}), - Event('{{state_name}}.{{const.on_load_internal}}') -] +// Theses events are triggered on initial load and each page navigation. +export const onLoadInternalEvent = () => { + const internal_events = []; + + // Get tracked cookie and local storage vars to send to the backend. + const client_storage_vars = hydrateClientStorage(clientStorage); + // But only send the vars if any are actually set in the browser. + if (client_storage_vars && Object.keys(client_storage_vars).length !== 0) { + internal_events.push( + Event( + '{{state_name}}.{{const.update_vars_internal}}', + {vars: client_storage_vars}, + ), + ); + } + + // `on_load_internal` triggers the correct on_load event(s) for the current page. + // If the page does not define any on_load event, this will just set `is_hydrated = true`. + internal_events.push(Event('{{state_name}}.{{const.on_load_internal}}')); + + return internal_events; +} + +// The following events are sent when the websocket connects or reconnects. export const initialEvents = () => [ Event('{{state_name}}.{{const.hydrate}}'), ...onLoadInternalEvent() diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 8dcba5dc4..13da2c510 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -587,7 +587,7 @@ export const useEventLoop = ( if (storage_to_state_map[e.key]) { const vars = {} vars[storage_to_state_map[e.key]] = e.newValue - const event = Event(`${state_name}.update_vars_internal`, {vars: vars}) + const event = Event(`${state_name}.update_vars_internal_state.update_vars_internal`, {vars: vars}) addEvents([event], e); } }; diff --git a/reflex/app.py b/reflex/app.py index 43def40ed..fca37d9ef 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -69,9 +69,11 @@ from reflex.state import ( State, StateManager, StateUpdate, + _substate_key, code_uses_state_contexts, ) from reflex.utils import console, exceptions, format, prerequisites, types +from reflex.utils.exec import is_testing_env from reflex.utils.imports import ImportVar # Define custom types. @@ -159,10 +161,9 @@ class App(Base): ) super().__init__(*args, **kwargs) state_subclasses = BaseState.__subclasses__() - is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ # Special case to allow test cases have multiple subclasses of rx.BaseState. - if not is_testing_env: + if not is_testing_env(): # Only one Base State class is allowed. if len(state_subclasses) > 1: raise ValueError( @@ -176,7 +177,8 @@ class App(Base): deprecation_version="0.3.5", removal_version="0.5.0", ) - if len(State.class_subclasses) > 0: + # 2 substates are built-in and not considered when determining if app is stateless. + if len(State.class_subclasses) > 2: self.state = State # Get the config config = get_config() @@ -1002,7 +1004,7 @@ def upload(app: App): ) # Get the state for the session. - substate_token = token + "_" + handler.rpartition(".")[0] + substate_token = _substate_key(token, handler.rpartition(".")[0]) state = await app.state_manager.get_state(substate_token) # get the current session ID diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index ac53e8b02..c86e890d0 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -138,12 +138,12 @@ def compile_state(state: Type[BaseState]) -> dict: A dictionary of the compiled state. """ try: - initial_state = state().dict(initial=True) + initial_state = state(_reflex_internal_init=True).dict(initial=True) except Exception as e: console.warn( f"Failed to compile initial state with computed vars, excluding them: {e}" ) - initial_state = state().dict(include_computed=False) + initial_state = state(_reflex_internal_init=True).dict(include_computed=False) return format.format_state(initial_state) diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 4efb68e22..2ec23f5e9 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -59,9 +59,9 @@ class CompileVars(SimpleNamespace): # The name of the function for converting a dict to an event. TO_EVENT = "Event" # The name of the internal on_load event. - ON_LOAD_INTERNAL = "on_load_internal" + ON_LOAD_INTERNAL = "on_load_internal_state.on_load_internal" # The name of the internal event to update generic state vars. - UPDATE_VARS_INTERNAL = "update_vars_internal" + UPDATE_VARS_INTERNAL = "update_vars_internal_state.update_vars_internal" class PageNames(SimpleNamespace): diff --git a/reflex/state.py b/reflex/state.py index 11b1f300d..781fecdfe 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -8,7 +8,6 @@ import copy import functools import inspect import json -import os import traceback import urllib.parse import uuid @@ -45,6 +44,7 @@ from reflex.event import ( ) from reflex.utils import console, format, prerequisites, types from reflex.utils.exceptions import ImmutableStateError, LockExpiredError +from reflex.utils.exec import is_testing_env from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.vars import BaseVar, ComputedVar, Var, computed_var @@ -151,9 +151,45 @@ RESERVED_BACKEND_VAR_NAMES = { "_substate_var_dependencies", "_always_dirty_computed_vars", "_always_dirty_substates", + "_was_touched", } +def _substate_key( + token: str, + state_cls_or_name: BaseState | Type[BaseState] | str | list[str], +) -> str: + """Get the substate key. + + Args: + token: The token of the state. + state_cls_or_name: The state class/instance or name or sequence of name parts. + + Returns: + The substate key. + """ + if isinstance(state_cls_or_name, BaseState) or ( + isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState) + ): + state_cls_or_name = state_cls_or_name.get_full_name() + elif isinstance(state_cls_or_name, (list, tuple)): + state_cls_or_name = ".".join(state_cls_or_name) + return f"{token}_{state_cls_or_name}" + + +def _split_substate_key(substate_key: str) -> tuple[str, str]: + """Split the substate key into token and state name. + + Args: + substate_key: The substate key. + + Returns: + Tuple of token and state name. + """ + token, _, state_name = substate_key.partition("_") + return token, state_name + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -214,29 +250,46 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The router data for the current page router: RouterData = RouterData() + # Whether the state has ever been touched since instantiation. + _was_touched: bool = False + def __init__( self, *args, parent_state: BaseState | None = None, init_substates: bool = True, + _reflex_internal_init: bool = False, **kwargs, ): """Initialize the state. + DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead. + Args: *args: The args to pass to the Pydantic init method. parent_state: The parent state. init_substates: Whether to initialize the substates in this instance. + _reflex_internal_init: A flag to indicate that the state is being initialized by the framework. **kwargs: The kwargs to pass to the Pydantic init method. + Raises: + RuntimeError: If the state is instantiated directly by end user. """ + if not _reflex_internal_init and not is_testing_env(): + raise RuntimeError( + "State classes should not be instantiated directly in a Reflex app. " + "See https://reflex.dev/docs/state for further information." + ) kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) # Setup the substates (for memory state manager only). if init_substates: for substate in self.get_substates(): - self.substates[substate.get_name()] = substate(parent_state=self) + self.substates[substate.get_name()] = substate( + parent_state=self, + _reflex_internal_init=True, + ) # Convert the event handlers to functions. self._init_event_handlers() @@ -287,7 +340,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Raises: ValueError: If a substate class shadows another. """ - is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ super().__init_subclass__(**kwargs) # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() @@ -295,6 +347,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Reset subclass tracking for this class. cls.class_subclasses = set() + # Reset dirty substate tracking for this class. + cls._always_dirty_substates = set() + # Get the parent vars. parent_state = cls.get_parent_state() if parent_state is not None: @@ -303,7 +358,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Check if another substate class with the same name has already been defined. if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses): - if is_testing_env: + if is_testing_env(): # Clear existing subclass with same name when app is reloaded via # utils.prerequisites.get_app(reload=True) parent_state.class_subclasses = set( @@ -325,6 +380,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): name: value for name, value in cls.__dict__.items() if types.is_backend_variable(name, cls) + and name not in RESERVED_BACKEND_VAR_NAMES and name not in cls.inherited_backend_vars and not isinstance(value, FunctionType) } @@ -484,7 +540,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) # Any substate containing a ComputedVar with cache=False always needs to be recomputed - cls._always_dirty_substates = set() if cls._always_dirty_computed_vars: # Tell parent classes that this substate has always dirty computed vars state_name = cls.get_name() @@ -923,8 +978,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): **super().__getattribute__("inherited_vars"), **super().__getattribute__("inherited_backend_vars"), } - if name in inherited_vars: - return getattr(super().__getattribute__("parent_state"), name) + + # For now, handle router_data updates as a special case. + if name in inherited_vars or name == constants.ROUTER_DATA: + parent_state = super().__getattribute__("parent_state") + if parent_state is not None: + return getattr(parent_state, name) backend_vars = super().__getattribute__("_backend_vars") if name in backend_vars: @@ -980,9 +1039,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if name == constants.ROUTER_DATA: self.dirty_vars.add(name) self._mark_dirty() - # propagate router_data updates down the state tree - for substate in self.substates.values(): - setattr(substate, name, value) def reset(self): """Reset all the base vars to their default values.""" @@ -1036,6 +1092,170 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): raise ValueError(f"Invalid path: {path}") return self.substates[path[0]].get_substate(path[1:]) + @classmethod + def _get_common_ancestor(cls, other: Type[BaseState]) -> str: + """Find the name of the nearest common ancestor shared by this and the other state. + + Args: + other: The other state. + + Returns: + Full name of the nearest common ancestor. + """ + common_ancestor_parts = [] + for part1, part2 in zip( + cls.get_full_name().split("."), + other.get_full_name().split("."), + ): + if part1 != part2: + break + common_ancestor_parts.append(part1) + return ".".join(common_ancestor_parts) + + @classmethod + def _determine_missing_parent_states( + cls, target_state_cls: Type[BaseState] + ) -> tuple[str, list[str]]: + """Determine the missing parent states between the target_state_cls and common ancestor of this state. + + Args: + target_state_cls: The class of the state to find missing parent states for. + + Returns: + The name of the common ancestor and the list of missing parent states. + """ + common_ancestor_name = cls._get_common_ancestor(target_state_cls) + common_ancestor_parts = common_ancestor_name.split(".") + target_state_parts = tuple(target_state_cls.get_full_name().split(".")) + relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :] + + # Determine which parent states to fetch from the common ancestor down to the target_state_cls. + fetch_parent_states = [common_ancestor_name] + for ix, relative_parent_state_name in enumerate(relative_target_state_parts): + fetch_parent_states.append( + ".".join([*fetch_parent_states[: ix + 1], relative_parent_state_name]) + ) + + return common_ancestor_name, fetch_parent_states[1:-1] + + def _get_parent_states(self) -> list[tuple[str, BaseState]]: + """Get all parent state instances up to the root of the state tree. + + Returns: + A list of tuples containing the name and the instance of each parent state. + """ + parent_states_with_name = [] + parent_state = self + while parent_state.parent_state is not None: + parent_state = parent_state.parent_state + parent_states_with_name.append((parent_state.get_full_name(), parent_state)) + return parent_states_with_name + + async def _populate_parent_states(self, target_state_cls: Type[BaseState]): + """Populate substates in the tree between the target_state_cls and common ancestor of this state. + + Args: + target_state_cls: The class of the state to populate parent states for. + + Returns: + The parent state instance of target_state_cls. + + Raises: + RuntimeError: If redis is not used in this backend process. + """ + state_manager = get_state_manager() + if not isinstance(state_manager, StateManagerRedis): + raise RuntimeError( + f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. " + "(All states should already be available -- this is likely a bug).", + ) + + # Find the missing parent states up to the common ancestor. + ( + common_ancestor_name, + missing_parent_states, + ) = self._determine_missing_parent_states(target_state_cls) + + # Fetch all missing parent states and link them up to the common ancestor. + parent_states_by_name = dict(self._get_parent_states()) + parent_state = parent_states_by_name[common_ancestor_name] + for parent_state_name in missing_parent_states: + parent_state = await state_manager.get_state( + token=_substate_key( + self.router.session.client_token, parent_state_name + ), + top_level=False, + get_substates=False, + parent_state=parent_state, + ) + + # Return the direct parent of target_state_cls for subsequent linking. + return parent_state + + def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: + """Get a state instance from the cache. + + Args: + state_cls: The class of the state. + + Returns: + The instance of state_cls associated with this state's client_token. + """ + if self.parent_state is None: + root_state = self + else: + root_state = self._get_parent_states()[-1][1] + return root_state.get_substate(state_cls.get_full_name().split(".")) + + async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: + """Get a state instance from redis. + + Args: + state_cls: The class of the state. + + Returns: + The instance of state_cls associated with this state's client_token. + + Raises: + RuntimeError: If redis is not used in this backend process. + """ + # Fetch all missing parent states from redis. + parent_state_of_state_cls = await self._populate_parent_states(state_cls) + + # Then get the target state and all its substates. + state_manager = get_state_manager() + if not isinstance(state_manager, StateManagerRedis): + raise RuntimeError( + f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " + "(All states should already be available -- this is likely a bug).", + ) + return await state_manager.get_state( + token=_substate_key(self.router.session.client_token, state_cls), + top_level=False, + get_substates=True, + parent_state=parent_state_of_state_cls, + ) + + async def get_state(self, state_cls: Type[BaseState]) -> BaseState: + """Get an instance of the state associated with this token. + + Allows for arbitrary access to sibling states from within an event handler. + + Args: + state_cls: The class of the state. + + Returns: + The instance of state_cls associated with this state's client_token. + """ + # Fast case - if this state instance is already cached, get_substate from root state. + try: + return self._get_state_from_cache(state_cls) + except ValueError: + pass + + # Slow case - fetch missing parent states from redis. + return await self._get_state_from_redis(state_cls) + def _get_event_handler( self, event: Event ) -> tuple[BaseState | StateProxy, EventHandler]: @@ -1238,6 +1458,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for cvar in self._computed_var_dependencies[dirty_var] ) + @classmethod + def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: + """Determine substates which could be affected by dirty vars in this state. + + Returns: + Set of State classes that may need to be fetched to recalc computed vars. + """ + # _always_dirty_substates need to be fetched to recalc computed vars. + fetch_substates = set( + cls.get_class_substate(tuple(substate_name.split("."))) + for substate_name in cls._always_dirty_substates + ) + # Substates with cached vars also need to be fetched. + for dependent_substates in cls._substate_var_dependencies.values(): + fetch_substates.update( + set( + cls.get_class_substate(tuple(substate_name.split("."))) + for substate_name in dependent_substates + ) + ) + return fetch_substates + def get_delta(self) -> Delta: """Get the delta for the state. @@ -1269,8 +1511,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Recursively find the substate deltas. substates = self.substates for substate in self.dirty_substates.union(self._always_dirty_substates): - if substate not in substates: - continue # substate not loaded at this time, no delta delta.update(substates[substate].get_delta()) # Format the delta. @@ -1292,20 +1532,45 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function self._mark_dirty_computed_vars() + self._mark_dirty_substates() - # Propagate dirty var / computed var status into substates + def _mark_dirty_substates(self): + """Propagate dirty var / computed var status into substates.""" substates = self.substates for var in self.dirty_vars: for substate_name in self._substate_var_dependencies[var]: self.dirty_substates.add(substate_name) - if substate_name not in substates: - continue substate = substates[substate_name] substate.dirty_vars.add(var) substate._mark_dirty() + def _update_was_touched(self): + """Update the _was_touched flag based on dirty_vars.""" + if self.dirty_vars and not self._was_touched: + for var in self.dirty_vars: + if var in self.base_vars or var in self._backend_vars: + self._was_touched = True + break + + def _get_was_touched(self) -> bool: + """Check current dirty_vars and flag to determine if state instance was modified. + + If any dirty vars belong to this state, mark _was_touched. + + This flag determines whether this state instance should be persisted to redis. + + Returns: + Whether this state instance was ever modified. + """ + # Ensure the flag is up to date based on the current dirty_vars + self._update_was_touched() + return self._was_touched + def _clean(self): """Reset the dirty vars.""" + # Update touched status before cleaning dirty_vars. + self._update_was_touched() + # Recursively clean the substates. for substate in self.dirty_substates: if substate not in self.substates: @@ -1422,6 +1687,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"] = state["__dict__"].copy() state["__dict__"]["parent_state"] = None state["__dict__"]["substates"] = {} + state["__dict__"].pop("_was_touched", None) return state @@ -1431,28 +1697,11 @@ class State(BaseState): # The hydrated bool. is_hydrated: bool = False - def on_load_internal(self) -> list[Event | EventSpec] | None: - """Queue on_load handlers for the current page. - Returns: - The list of events to queue for on load handling. - """ - # Do not app.compile_()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - load_events = app.get_load_events(self.router.page.path) - if not load_events and self.is_hydrated: - return # Fast path for page-to-page navigation - self.is_hydrated = False - return [ - *fix_events( - load_events, - self.router.session.client_token, - router_data=self.router_data, - ), - type(self).set_is_hydrated(True), # type: ignore - ] +class UpdateVarsInternalState(State): + """Substate for handling internal state var updates.""" - def update_vars_internal(self, vars: dict[str, Any]) -> None: + async def update_vars_internal(self, vars: dict[str, Any]) -> None: """Apply updates to fully qualified state vars. The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`, @@ -1466,10 +1715,42 @@ class State(BaseState): """ for var, value in vars.items(): state_name, _, var_name = var.rpartition(".") - var_state = self.get_substate(state_name.split(".")) + var_state_cls = State.get_class_substate(tuple(state_name.split("."))) + var_state = await self.get_state(var_state_cls) setattr(var_state, var_name, value) +class OnLoadInternalState(State): + """Substate for handling on_load event enumeration. + + This is a separate substate to avoid deserializing the entire state tree for every page navigation. + """ + + def on_load_internal(self) -> list[Event | EventSpec] | None: + """Queue on_load handlers for the current page. + + Returns: + The list of events to queue for on load handling. + """ + # Do not app.compile_()! It should be already compiled by now. + app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + load_events = app.get_load_events(self.router.page.path) + if not load_events and self.is_hydrated: + return # Fast path for page-to-page navigation + if not load_events: + self.is_hydrated = True + return # Fast path for initial hydrate with no on_load events defined. + self.is_hydrated = False + return [ + *fix_events( + load_events, + self.router.session.client_token, + router_data=self.router_data, + ), + State.set_is_hydrated(True), # type: ignore + ] + + class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. @@ -1522,9 +1803,10 @@ class StateProxy(wrapt.ObjectProxy): This StateProxy instance in mutable mode. """ self._self_actx = self._self_app.modify_state( - self.__wrapped__.router.session.client_token - + "_" - + ".".join(self._self_substate_path) + token=_substate_key( + self.__wrapped__.router.session.client_token, + self._self_substate_path, + ) ) mutable_state = await self._self_actx.__aenter__() super().__setattr__( @@ -1574,7 +1856,15 @@ class StateProxy(wrapt.ObjectProxy): Returns: The value of the attribute. + + Raises: + ImmutableStateError: If the state is not in mutable mode. """ + if name in ["substates", "parent_state"] and not self._self_mutable: + raise ImmutableStateError( + "Background task StateProxy is immutable outside of a context " + "manager. Use `async with self` to modify state." + ) value = super().__getattr__(name) if not name.startswith("_self_") and isinstance(value, MutableProxy): # ensure mutations to these containers are blocked unless proxy is _mutable @@ -1622,6 +1912,60 @@ class StateProxy(wrapt.ObjectProxy): "manager. Use `async with self` to modify state." ) + def get_substate(self, path: Sequence[str]) -> BaseState: + """Only allow substate access with lock held. + + Args: + path: The path to the substate. + + Returns: + The substate. + + Raises: + ImmutableStateError: If the state is not in mutable mode. + """ + if not self._self_mutable: + raise ImmutableStateError( + "Background task StateProxy is immutable outside of a context " + "manager. Use `async with self` to modify state." + ) + return self.__wrapped__.get_substate(path) + + async def get_state(self, state_cls: Type[BaseState]) -> BaseState: + """Get an instance of the state associated with this token. + + Args: + state_cls: The class of the state. + + Returns: + The state. + + Raises: + ImmutableStateError: If the state is not in mutable mode. + """ + if not self._self_mutable: + raise ImmutableStateError( + "Background task StateProxy is immutable outside of a context " + "manager. Use `async with self` to modify state." + ) + return await self.__wrapped__.get_state(state_cls) + + def _as_state_update(self, *args, **kwargs) -> StateUpdate: + """Temporarily allow mutability to access parent_state. + + Args: + *args: The args to pass to the underlying state instance. + **kwargs: The kwargs to pass to the underlying state instance. + + Returns: + The state update. + """ + self._self_mutable = True + try: + return self.__wrapped__._as_state_update(*args, **kwargs) + finally: + self._self_mutable = False + class StateUpdate(Base): """A state update sent to the frontend.""" @@ -1722,9 +2066,9 @@ class StateManagerMemory(StateManager): The state for the token. """ # Memory state manager ignores the substate suffix and always returns the top-level state. - token = token.partition("_")[0] + token = _split_substate_key(token)[0] if token not in self.states: - self.states[token] = self.state() + self.states[token] = self.state(_reflex_internal_init=True) return self.states[token] async def set_state(self, token: str, state: BaseState): @@ -1747,7 +2091,7 @@ class StateManagerMemory(StateManager): The state for the token. """ # Memory state manager ignores the substate suffix and always returns the top-level state. - token = token.partition("_")[0] + token = _split_substate_key(token)[0] if token not in self._states_locks: async with self._state_manager_lock: if token not in self._states_locks: @@ -1787,6 +2131,81 @@ class StateManagerRedis(StateManager): b"evicted", } + def _get_root_state(self, state: BaseState) -> BaseState: + """Chase parent_state pointers to find an instance of the top-level state. + + Args: + state: The state to start from. + + Returns: + An instance of the top-level state (self.state). + """ + while type(state) != self.state and state.parent_state is not None: + state = state.parent_state + return state + + async def _get_parent_state(self, token: str) -> BaseState | None: + """Get the parent state for the state requested in the token. + + Args: + token: The token to get the state for (_substate_key). + + Returns: + The parent state for the state requested by the token or None if there is no such parent. + """ + parent_state = None + client_token, state_path = _split_substate_key(token) + parent_state_name = state_path.rpartition(".")[0] + if parent_state_name: + # Retrieve the parent state to populate event handlers onto this substate. + parent_state = await self.get_state( + token=_substate_key(client_token, parent_state_name), + top_level=False, + get_substates=False, + ) + return parent_state + + async def _populate_substates( + self, + token: str, + state: BaseState, + all_substates: bool = False, + ): + """Fetch and link substates for the given state instance. + + There is no return value; the side-effect is that `state` will have `substates` populated, + and each substate will have its `parent_state` set to `state`. + + Args: + token: The token to get the state for. + state: The state instance to populate substates for. + all_substates: Whether to fetch all substates or just required substates. + """ + client_token, _ = _split_substate_key(token) + + if all_substates: + # All substates are requested. + fetch_substates = state.get_substates() + else: + # Only _potentially_dirty_substates need to be fetched to recalc computed vars. + fetch_substates = state._potentially_dirty_substates() + + tasks = {} + # Retrieve the necessary substates from redis. + for substate_cls in fetch_substates: + substate_name = substate_cls.get_name() + tasks[substate_name] = asyncio.create_task( + self.get_state( + token=_substate_key(client_token, substate_cls), + top_level=False, + get_substates=all_substates, + parent_state=state, + ) + ) + + for substate_name, substate_task in tasks.items(): + state.substates[substate_name] = await substate_task + async def get_state( self, token: str, @@ -1798,8 +2217,8 @@ class StateManagerRedis(StateManager): Args: token: The token to get the state for. - top_level: If true, return an instance of the top-level state. - get_substates: If true, also retrieve substates + top_level: If true, return an instance of the top-level state (self.state). + get_substates: If true, also retrieve substates. parent_state: If provided, use this parent_state instead of getting it from redis. Returns: @@ -1809,7 +2228,7 @@ class StateManagerRedis(StateManager): RuntimeError: when the state_cls is not specified in the token """ # Split the actual token from the fully qualified substate name. - client_token, _, state_path = token.partition("_") + _, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. state_cls = self.state.get_class_substate(tuple(state_path.split("."))) @@ -1825,66 +2244,49 @@ class StateManagerRedis(StateManager): # Deserialize the substate. state = cloudpickle.loads(redis_state) - # Populate parent and substates if requested. + # Populate parent state if missing and requested. if parent_state is None: - # Retrieve the parent state from redis. - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - parent_state_key = token.rpartition(".")[0] - parent_state = await self.get_state( - parent_state_key, top_level=False, get_substates=False - ) + parent_state = await self._get_parent_state(token) # Set up Bidirectional linkage between this state and its parent. if parent_state is not None: parent_state.substates[state.get_name()] = state state.parent_state = parent_state - if get_substates: - # Retrieve all substates from redis. - for substate_cls in state_cls.get_substates(): - substate_name = substate_cls.get_name() - substate_key = token + "." + substate_name - state.substates[substate_name] = await self.get_state( - substate_key, top_level=False, parent_state=state - ) + # Populate substates if requested. + await self._populate_substates(token, state, all_substates=get_substates) + # To retain compatibility with previous implementation, by default, we return # the top-level state by chasing `parent_state` pointers up the tree. if top_level: - while type(state) != self.state and state.parent_state is not None: - state = state.parent_state + return self._get_root_state(state) return state - # Key didn't exist so we have to create a new entry for this token. + # TODO: dedupe the following logic with the above block + # Key didn't exist so we have to create a new instance for this token. if parent_state is None: - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - # Retrieve the parent state to populate event handlers onto this substate. - parent_state_key = client_token + "_" + parent_state_name - parent_state = await self.get_state( - parent_state_key, top_level=False, get_substates=False - ) - # Persist the new state class to redis. - await self.set_state( - token, - state_cls( - parent_state=parent_state, - init_substates=False, - ), - ) - # After creating the state key, recursively call `get_state` to populate substates. - return await self.get_state( - token, - top_level=top_level, - get_substates=get_substates, + parent_state = await self._get_parent_state(token) + # Instantiate the new state class (but don't persist it yet). + state = state_cls( parent_state=parent_state, + init_substates=False, + _reflex_internal_init=True, ) + # Set up Bidirectional linkage between this state and its parent. + if parent_state is not None: + parent_state.substates[state.get_name()] = state + state.parent_state = parent_state + # Populate substates for the newly created state. + await self._populate_substates(token, state, all_substates=get_substates) + # To retain compatibility with previous implementation, by default, we return + # the top-level state by chasing `parent_state` pointers up the tree. + if top_level: + return self._get_root_state(state) + return state async def set_state( self, token: str, state: BaseState, lock_id: bytes | None = None, - set_substates: bool = True, - set_parent_state: bool = True, ): """Set the state for a token. @@ -1892,11 +2294,10 @@ class StateManagerRedis(StateManager): token: The token to set the state for. state: The state to set. lock_id: If provided, the lock_key must be set to this value to set the state. - set_substates: If True, write substates to redis - set_parent_state: If True, write parent state to redis Raises: LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. + RuntimeError: If the state instance doesn't match the state name in the token. """ # Check that we're holding the lock. if ( @@ -1908,28 +2309,36 @@ class StateManagerRedis(StateManager): f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.background` decorator for long-running tasks." ) - # Find the substate associated with the token. - state_path = token.partition("_")[2] - if state_path and state.get_full_name() != state_path: - state = state.get_substate(tuple(state_path.split("."))) - # Persist the parent state separately, if requested. - if state.parent_state is not None and set_parent_state: - parent_state_key = token.rpartition(".")[0] - await self.set_state( - parent_state_key, - state.parent_state, - lock_id=lock_id, - set_substates=False, + client_token, substate_name = _split_substate_key(token) + # If the substate name on the token doesn't match the instance name, it cannot have a parent. + if state.parent_state is not None and state.get_full_name() != substate_name: + raise RuntimeError( + f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." ) - # Persist the substates separately, if requested. - if set_substates: - for substate_name, substate in state.substates.items(): - substate_key = token + "." + substate_name - await self.set_state( - substate_key, substate, lock_id=lock_id, set_parent_state=False + + # Recursively set_state on all known substates. + tasks = [] + for substate in state.substates.values(): + tasks.append( + asyncio.create_task( + self.set_state( + token=_substate_key(client_token, substate), + state=substate, + lock_id=lock_id, + ) ) + ) # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). - await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) + if state._get_was_touched(): + await self.redis.set( + _substate_key(client_token, state), + cloudpickle.dumps(state), + ex=self.token_expiration, + ) + + # Wait for substates to be persisted. + for t in tasks: + await t @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: @@ -1957,7 +2366,7 @@ class StateManagerRedis(StateManager): The redis lock key for the token. """ # All substates share the same lock domain, so ignore any substate path suffix. - client_token = token.partition("_")[0] + client_token = _split_substate_key(token)[0] return f"{client_token}_lock".encode() async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: @@ -2052,6 +2461,16 @@ class StateManagerRedis(StateManager): await self.redis.close(close_connection_pool=True) +def get_state_manager() -> StateManager: + """Get the state manager for the app that is currently running. + + Returns: + The state manager. + """ + app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + return app.state_manager + + class ClientStorageBase: """Base class for client-side storage.""" diff --git a/reflex/testing.py b/reflex/testing.py index 78d220847..dadce40a5 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -70,6 +70,10 @@ else: FRONTEND_POPEN_ARGS["start_new_session"] = True +# Save a copy of internal substates to reset after each test. +INTERNAL_STATES = State.class_subclasses.copy() + + # borrowed from py3.11 class chdir(contextlib.AbstractContextManager): """Non thread-safe context manager to change the current working directory.""" @@ -220,6 +224,8 @@ class AppHarness: reflex.config.get_config(reload=True) # reset rx.State subclasses State.class_subclasses.clear() + State.class_subclasses.update(INTERNAL_STATES) + State._always_dirty_substates = set() State.get_class_substate.cache_clear() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index a992534b4..e3c7eb586 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -285,3 +285,12 @@ def output_system_info(): console.debug(f"Using package executer at: {prerequisites.get_package_manager()}") # type: ignore if system != "Windows": console.debug(f"Unzip path: {path_ops.which('unzip')}") + + +def is_testing_env() -> bool: + """Whether the app is running in a testing environment. + + Returns: + True if the app is running in under pytest. + """ + return constants.PYTEST_CURRENT_TEST in os.environ diff --git a/reflex/vars.py b/reflex/vars.py index ed01e8d33..eed60a946 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1875,6 +1875,10 @@ class ComputedVar(Var, property): Returns: A set of variable names accessed by the given obj. + + Raises: + ValueError: if the function references the get_state, parent_state, or substates attributes + (cannot track deps in a related state, only implicitly via parent state). """ d = set() if obj is None: @@ -1898,6 +1902,8 @@ class ComputedVar(Var, property): if self_name is None: # cannot reference attributes on self if method takes no args return set() + + invalid_names = ["get_state", "parent_state", "substates", "get_substate"] self_is_top_of_stack = False for instruction in dis.get_instructions(obj): if ( @@ -1916,6 +1922,10 @@ class ComputedVar(Var, property): ref_obj = getattr(objclass, instruction.argval) except Exception: ref_obj = None + if instruction.argval in invalid_names: + raise ValueError( + f"Cached var {self._var_full_name} cannot access arbitrary state via `{instruction.argval}`." + ) if callable(ref_obj): # recurse into callable attributes d.update( diff --git a/tests/test_app.py b/tests/test_app.py index b3153078a..0da6c11f9 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -29,7 +29,15 @@ from reflex.components.radix.themes.typography.text import Text from reflex.event import Event from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate +from reflex.state import ( + BaseState, + OnLoadInternalState, + RouterData, + State, + StateManagerRedis, + StateUpdate, + _substate_key, +) from reflex.style import Style from reflex.utils import format from reflex.vars import ComputedVar @@ -362,7 +370,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str): assert app.state == test_state # Get a state for a given token. - state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}") + state = await app.state_manager.get_state(_substate_key(token, test_state)) assert isinstance(state, test_state) assert state.var == 0 # type: ignore @@ -766,8 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): # The App state must be the "root" of the state tree app = App(state=State) app.event_namespace.emit = AsyncMock() # type: ignore - substate_token = f"{token}_{state.get_full_name()}" - current_state = await app.state_manager.get_state(substate_token) + current_state = await app.state_manager.get_state(_substate_key(token, state)) data = b"This is binary data" # Create a binary IO object and write data to it @@ -796,7 +803,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): == StateUpdate(delta=delta, events=[], final=True).json() + "\n" ) - current_state = await app.state_manager.get_state(substate_token) + current_state = await app.state_manager.get_state(_substate_key(token, state)) state_dict = current_state.dict()[state.get_full_name()] assert state_dict["img_list"] == [ "image1.jpg", @@ -913,7 +920,7 @@ class DynamicState(BaseState): # self.side_effect_counter = self.side_effect_counter + 1 return self.dynamic - on_load_internal = State.on_load_internal.fn + on_load_internal = OnLoadInternalState.on_load_internal.fn @pytest.mark.asyncio @@ -950,7 +957,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( } assert constants.ROUTER in app.state()._computed_var_dependencies - substate_token = f"{token}_{DynamicState.get_full_name()}" + substate_token = _substate_key(token, DynamicState) sid = "mock_sid" client_ip = "127.0.0.1" state = await app.state_manager.get_state(substate_token) @@ -978,7 +985,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): on_load_internal = _event( - name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}", + name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", val=exp_val, ) exp_router_data = { @@ -1013,8 +1020,8 @@ async def test_dynamic_route_var_route_change_completed_on_load( name="on_load", val=exp_val, ), - _dynamic_state_event( - name="set_is_hydrated", + _event( + name="state.set_is_hydrated", payload={"value": True}, val=exp_val, router_data={}, diff --git a/tests/test_state.py b/tests/test_state.py index b00434c42..c1c692a31 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -23,6 +23,7 @@ from reflex.state import ( ImmutableStateError, LockExpiredError, MutableProxy, + OnLoadInternalState, RouterData, State, StateManager, @@ -30,6 +31,7 @@ from reflex.state import ( StateManagerRedis, StateProxy, StateUpdate, + _substate_key, ) from reflex.utils import prerequisites, types from reflex.utils.format import json_dumps @@ -139,6 +141,12 @@ class ChildState2(TestState): value: str +class ChildState3(TestState): + """A child state fixture.""" + + value: str + + class GrandchildState(ChildState): """A grandchild state fixture.""" @@ -149,6 +157,32 @@ class GrandchildState(ChildState): pass +class GrandchildState2(ChildState2): + """A grandchild state fixture.""" + + @rx.cached_var + def cached(self) -> str: + """A cached var. + + Returns: + The value. + """ + return self.value + + +class GrandchildState3(ChildState3): + """A great grandchild state fixture.""" + + @rx.var + def computed(self) -> str: + """A computed var. + + Returns: + The value. + """ + return self.value + + class DateTimeState(BaseState): """A State with some datetime fields.""" @@ -329,6 +363,9 @@ def test_dict(test_state): "test_state.child_state", "test_state.child_state.grandchild_state", "test_state.child_state2", + "test_state.child_state2.grandchild_state2", + "test_state.child_state3", + "test_state.child_state3.grandchild_state3", } test_state_dict = test_state.dict() assert set(test_state_dict) == substates @@ -380,10 +417,11 @@ def test_get_parent_state(): def test_get_substates(): """Test getting the substates.""" - assert TestState.get_substates() == {ChildState, ChildState2} + assert TestState.get_substates() == {ChildState, ChildState2, ChildState3} assert ChildState.get_substates() == {GrandchildState} - assert ChildState2.get_substates() == set() + assert ChildState2.get_substates() == {GrandchildState2} assert GrandchildState.get_substates() == set() + assert GrandchildState2.get_substates() == set() def test_get_name(): @@ -469,8 +507,8 @@ def test_set_parent_and_substates(test_state, child_state, grandchild_state): child_state: A child state. grandchild_state: A grandchild state. """ - assert len(test_state.substates) == 2 - assert set(test_state.substates) == {"child_state", "child_state2"} + assert len(test_state.substates) == 3 + assert set(test_state.substates) == {"child_state", "child_state2", "child_state3"} assert child_state.parent_state == test_state assert len(child_state.substates) == 1 @@ -655,7 +693,7 @@ def test_reset(test_state, child_state): assert child_state.dirty_vars == {"count", "value"} # The dirty substates should be reset. - assert test_state.dirty_substates == {"child_state", "child_state2"} + assert test_state.dirty_substates == {"child_state", "child_state2", "child_state3"} @pytest.mark.asyncio @@ -675,7 +713,10 @@ async def test_process_event_simple(test_state): # The delta should contain the changes, including computed vars. # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}} - assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}} + assert update.delta == { + "test_state": {"num1": 69, "sum": 72.14, "upper": ""}, + "test_state.child_state3.grandchild_state3": {"computed": ""}, + } assert update.events == [] @@ -700,6 +741,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert update.delta == { "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state": {"value": "HI", "count": 24}, + "test_state.child_state3.grandchild_state3": {"computed": ""}, } test_state._clean() @@ -715,6 +757,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert update.delta == { "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state.grandchild_state": {"value2": "new"}, + "test_state.child_state3.grandchild_state3": {"computed": ""}, } @@ -1443,7 +1486,7 @@ def substate_token(state_manager, token): Returns: Token concatenated with the state_manager's state full_name. """ - return f"{token}_{state_manager.state.get_full_name()}" + return _substate_key(token, state_manager.state) @pytest.mark.asyncio @@ -1545,7 +1588,7 @@ def substate_token_redis(state_manager_redis, token): Returns: Token concatenated with the state_manager's state full_name. """ - return f"{token}_{state_manager_redis.state.get_full_name()}" + return _substate_key(token, state_manager_redis.state) @pytest.mark.asyncio @@ -1670,6 +1713,22 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # cannot directly modify state proxy outside of async context sp.value2 = "16" + with pytest.raises(ImmutableStateError): + # Cannot get_state + await sp.get_state(ChildState) + + with pytest.raises(ImmutableStateError): + # Cannot access get_substate + sp.get_substate([]) + + with pytest.raises(ImmutableStateError): + # Cannot access parent state + sp.parent_state.get_name() + + with pytest.raises(ImmutableStateError): + # Cannot access substates + sp.substates[""] + async with sp: assert sp._self_actx is not None assert sp._self_mutable # proxy is mutable inside context @@ -1685,8 +1744,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert sp.value2 == "42" # Get the state from the state manager directly and check that the value is updated - gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}" - gotten_state = await mock_app.state_manager.get_state(gc_token) + gotten_state = await mock_app.state_manager.get_state( + _substate_key(grandchild_state.router.session.client_token, grandchild_state) + ) if isinstance(mock_app.state_manager, StateManagerMemory): # For in-process store, only one instance of the state exists assert gotten_state is parent_state @@ -1710,6 +1770,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): grandchild_state.get_full_name(): { "value2": "42", }, + GrandchildState3.get_full_name(): { + "computed": "", + }, } ) assert mcall.kwargs["to"] == grandchild_state.get_sid() @@ -1879,8 +1942,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "private", ] - substate_token = f"{token}_{BackgroundTaskState.get_name()}" - assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order + assert ( + await mock_app.state_manager.get_state( + _substate_key(token, BackgroundTaskState) + ) + ).order == exp_order assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit @@ -1957,8 +2023,11 @@ async def test_background_task_reset(mock_app: rx.App, token: str): await task assert not mock_app.background_tasks - substate_token = f"{token}_{BackgroundTaskState.get_name()}" - assert (await mock_app.state_manager.get_state(substate_token)).order == [ + assert ( + await mock_app.state_manager.get_state( + _substate_key(token, BackgroundTaskState) + ) + ).order == [ "reset", ] @@ -2246,7 +2315,7 @@ def test_mutable_copy_vars(mutable_state, copy_func): def test_duplicate_substate_class(mocker): - mocker.patch("reflex.state.os.environ", {}) + mocker.patch("reflex.state.is_testing_env", lambda: False) with pytest.raises(ValueError): class TestState(BaseState): @@ -2435,7 +2504,9 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): expected: Expected delta. mocker: pytest mock object. """ - mocker.patch("reflex.state.State.class_subclasses", {test_state}) + mocker.patch( + "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} + ) app = app_module_mock.app = App( state=State, load_events={"index": [test_state.test_handler]} ) @@ -2476,7 +2547,9 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): token: A token. mocker: pytest mock object. """ - mocker.patch("reflex.state.State.class_subclasses", {OnLoadState}) + mocker.patch( + "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} + ) app = app_module_mock.app = App( state=State, load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]}, @@ -2510,3 +2583,120 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): OnLoadState.get_full_name(): {"num": 2} } assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state) + + +@pytest.mark.asyncio +async def test_get_state(mock_app: rx.App, token: str): + """Test that a get_state populates the top level state and delta calculation is correct. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + mock_app.state_manager.state = mock_app.state = TestState + + # Get instance of ChildState2. + test_state = await mock_app.state_manager.get_state( + _substate_key(token, ChildState2) + ) + assert isinstance(test_state, TestState) + if isinstance(mock_app.state_manager, StateManagerMemory): + # All substates are available + assert tuple(sorted(test_state.substates)) == ( + "child_state", + "child_state2", + "child_state3", + ) + else: + # Sibling states are only populated if they have computed vars + assert tuple(sorted(test_state.substates)) == ("child_state2", "child_state3") + + # Because ChildState3 has a computed var, it is always dirty, and always populated. + assert ( + test_state.substates["child_state3"].substates["grandchild_state3"].computed + == "" + ) + + # Get the child_state2 directly. + child_state2_direct = test_state.get_substate(["child_state2"]) + child_state2_get_state = await test_state.get_state(ChildState2) + # These should be the same object. + assert child_state2_direct is child_state2_get_state + + # Get arbitrary GrandchildState. + grandchild_state = await child_state2_get_state.get_state(GrandchildState) + assert isinstance(grandchild_state, GrandchildState) + + # Now the original root should have all substates populated. + assert tuple(sorted(test_state.substates)) == ( + "child_state", + "child_state2", + "child_state3", + ) + + # ChildState should be retrievable + child_state_direct = test_state.get_substate(["child_state"]) + child_state_get_state = await test_state.get_state(ChildState) + # These should be the same object. + assert child_state_direct is child_state_get_state + + # GrandchildState instance should be the same as the one retrieved from the child_state2. + assert grandchild_state is child_state_direct.get_substate(["grandchild_state"]) + grandchild_state.value2 = "set_value" + + assert test_state.get_delta() == { + TestState.get_full_name(): { + "sum": 3.14, + "upper": "", + }, + GrandchildState.get_full_name(): { + "value2": "set_value", + }, + GrandchildState3.get_full_name(): { + "computed": "", + }, + } + + # Get a fresh instance + new_test_state = await mock_app.state_manager.get_state( + _substate_key(token, ChildState2) + ) + assert isinstance(new_test_state, TestState) + if isinstance(mock_app.state_manager, StateManagerMemory): + # In memory, it's the same instance + assert new_test_state is test_state + test_state._clean() + # All substates are available + assert tuple(sorted(new_test_state.substates)) == ( + "child_state", + "child_state2", + "child_state3", + ) + else: + # With redis, we get a whole new instance + assert new_test_state is not test_state + # Sibling states are only populated if they have computed vars + assert tuple(sorted(new_test_state.substates)) == ( + "child_state2", + "child_state3", + ) + + # Set a value on child_state2, should update cached var in grandchild_state2 + child_state2 = new_test_state.get_substate(("child_state2",)) + child_state2.value = "set_c2_value" + + assert new_test_state.get_delta() == { + TestState.get_full_name(): { + "sum": 3.14, + "upper": "", + }, + ChildState2.get_full_name(): { + "value": "set_c2_value", + }, + GrandchildState2.get_full_name(): { + "cached": "set_c2_value", + }, + GrandchildState3.get_full_name(): { + "computed": "", + }, + } diff --git a/tests/test_state_tree.py b/tests/test_state_tree.py new file mode 100644 index 000000000..0747f900c --- /dev/null +++ b/tests/test_state_tree.py @@ -0,0 +1,371 @@ +"""Specialized test for a larger state tree.""" +import asyncio +from typing import Generator + +import pytest + +import reflex as rx +from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key + + +class Root(BaseState): + """Root of the state tree.""" + + root: int + + +class TreeA(Root): + """TreeA is a child of Root.""" + + a: int + + +class SubA_A(TreeA): + """SubA_A is a child of TreeA.""" + + sub_a_a: int + + +class SubA_A_A(SubA_A): + """SubA_A_A is a child of SubA_A.""" + + sub_a_a_a: int + + +class SubA_A_A_A(SubA_A_A): + """SubA_A_A_A is a child of SubA_A_A.""" + + sub_a_a_a_a: int + + +class SubA_A_A_B(SubA_A_A): + """SubA_A_A_B is a child of SubA_A_A.""" + + @rx.cached_var + def sub_a_a_a_cached(self) -> int: + """A cached var. + + Returns: + The value of sub_a_a_a + 1 + """ + return self.sub_a_a_a + 1 + + +class SubA_A_A_C(SubA_A_A): + """SubA_A_A_C is a child of SubA_A_A.""" + + sub_a_a_a_c: int + + +class SubA_A_B(SubA_A): + """SubA_A_B is a child of SubA_A.""" + + sub_a_a_b: int + + +class SubA_B(TreeA): + """SubA_B is a child of TreeA.""" + + sub_a_b: int + + +class TreeB(Root): + """TreeB is a child of Root.""" + + b: int + + +class SubB_A(TreeB): + """SubB_A is a child of TreeB.""" + + sub_b_a: int + + +class SubB_B(TreeB): + """SubB_B is a child of TreeB.""" + + sub_b_b: int + + +class SubB_C(TreeB): + """SubB_C is a child of TreeB.""" + + sub_b_c: int + + +class SubB_C_A(SubB_C): + """SubB_C_A is a child of SubB_C.""" + + sub_b_c_a: int + + +class TreeC(Root): + """TreeC is a child of Root.""" + + c: int + + +class SubC_A(TreeC): + """SubC_A is a child of TreeC.""" + + sub_c_a: int + + +class TreeD(Root): + """TreeD is a child of Root.""" + + d: int + + @rx.var + def d_var(self) -> int: + """A computed var. + + Returns: + The value of d + 1 + """ + return self.d + 1 + + +class TreeE(Root): + """TreeE is a child of Root.""" + + e: int + + +class SubE_A(TreeE): + """SubE_A is a child of TreeE.""" + + sub_e_a: int + + +class SubE_A_A(SubE_A): + """SubE_A_A is a child of SubE_A.""" + + sub_e_a_a: int + + +class SubE_A_A_A(SubE_A_A): + """SubE_A_A_A is a child of SubE_A_A.""" + + sub_e_a_a_a: int + + +class SubE_A_A_A_A(SubE_A_A_A): + """SubE_A_A_A_A is a child of SubE_A_A_A.""" + + sub_e_a_a_a_a: int + + @rx.var + def sub_e_a_a_a_a_var(self) -> int: + """A computed var. + + Returns: + The value of sub_e_a_a_a_a + 1 + """ + return self.sub_e_a_a_a + 1 + + +class SubE_A_A_A_B(SubE_A_A_A): + """SubE_A_A_A_B is a child of SubE_A_A_A.""" + + sub_e_a_a_a_b: int + + +class SubE_A_A_A_C(SubE_A_A_A): + """SubE_A_A_A_C is a child of SubE_A_A_A.""" + + sub_e_a_a_a_c: int + + +class SubE_A_A_A_D(SubE_A_A_A): + """SubE_A_A_A_D is a child of SubE_A_A_A.""" + + sub_e_a_a_a_d: int + + @rx.cached_var + def sub_e_a_a_a_d_var(self) -> int: + """A computed var. + + Returns: + The value of sub_e_a_a_a_a + 1 + """ + return self.sub_e_a_a_a + 1 + + +ALWAYS_COMPUTED_VARS = { + TreeD.get_full_name(): {"d_var": 1}, + SubE_A_A_A_A.get_full_name(): {"sub_e_a_a_a_a_var": 1}, +} + +ALWAYS_COMPUTED_DICT_KEYS = [ + Root.get_full_name(), + TreeD.get_full_name(), + TreeE.get_full_name(), + SubE_A.get_full_name(), + SubE_A_A.get_full_name(), + SubE_A_A_A.get_full_name(), + SubE_A_A_A_A.get_full_name(), + SubE_A_A_A_D.get_full_name(), +] + + +@pytest.fixture(scope="function") +def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]: + """Instance of state manager for redis only. + + Args: + app_module_mock: The app module mock fixture. + + Yields: + A state manager instance + """ + app_module_mock.app = rx.App(state=Root) + state_manager = app_module_mock.app.state_manager + + if not isinstance(state_manager, StateManagerRedis): + pytest.skip("Test requires redis") + + yield state_manager + + asyncio.get_event_loop().run_until_complete(state_manager.close()) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("substate_cls", "exp_root_substates", "exp_root_dict_keys"), + [ + ( + Root, + ["tree_a", "tree_b", "tree_c", "tree_d", "tree_e"], + [ + TreeA.get_full_name(), + SubA_A.get_full_name(), + SubA_A_A.get_full_name(), + SubA_A_A_A.get_full_name(), + SubA_A_A_B.get_full_name(), + SubA_A_A_C.get_full_name(), + SubA_A_B.get_full_name(), + SubA_B.get_full_name(), + TreeB.get_full_name(), + SubB_A.get_full_name(), + SubB_B.get_full_name(), + SubB_C.get_full_name(), + SubB_C_A.get_full_name(), + TreeC.get_full_name(), + SubC_A.get_full_name(), + SubE_A_A_A_B.get_full_name(), + SubE_A_A_A_C.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + TreeA, + ("tree_a", "tree_d", "tree_e"), + [ + TreeA.get_full_name(), + SubA_A.get_full_name(), + SubA_A_A.get_full_name(), + SubA_A_A_A.get_full_name(), + SubA_A_A_B.get_full_name(), + SubA_A_A_C.get_full_name(), + SubA_A_B.get_full_name(), + SubA_B.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + SubA_A_A_A, + ["tree_a", "tree_d", "tree_e"], + [ + TreeA.get_full_name(), + SubA_A.get_full_name(), + SubA_A_A.get_full_name(), + SubA_A_A_A.get_full_name(), + SubA_A_A_B.get_full_name(), # Cached var dep + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + TreeB, + ["tree_b", "tree_d", "tree_e"], + [ + TreeB.get_full_name(), + SubB_A.get_full_name(), + SubB_B.get_full_name(), + SubB_C.get_full_name(), + SubB_C_A.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + SubB_B, + ["tree_b", "tree_d", "tree_e"], + [ + TreeB.get_full_name(), + SubB_B.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + SubB_C_A, + ["tree_b", "tree_d", "tree_e"], + [ + TreeB.get_full_name(), + SubB_C.get_full_name(), + SubB_C_A.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + TreeC, + ["tree_c", "tree_d", "tree_e"], + [ + TreeC.get_full_name(), + SubC_A.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + TreeD, + ["tree_d", "tree_e"], + [ + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ( + TreeE, + ["tree_d", "tree_e"], + [ + # Extra siblings of computed var included now. + SubE_A_A_A_B.get_full_name(), + SubE_A_A_A_C.get_full_name(), + *ALWAYS_COMPUTED_DICT_KEYS, + ], + ), + ], +) +async def test_get_state_tree( + state_manager_redis, + token, + substate_cls, + exp_root_substates, + exp_root_dict_keys, +): + """Test getting state trees and assert on which branches are retrieved. + + Args: + state_manager_redis: The state manager redis fixture. + token: The token fixture. + substate_cls: The substate class to retrieve. + exp_root_substates: The expected substates of the root state. + exp_root_dict_keys: The expected keys of the root state dict. + """ + state = await state_manager_redis.get_state(_substate_key(token, substate_cls)) + assert isinstance(state, Root) + assert sorted(state.substates) == sorted(exp_root_substates) + + # Only computed vars should be returned + assert state.get_delta() == ALWAYS_COMPUTED_VARS + + # All of TreeA, TreeD, and TreeE substates should be in the dict + assert sorted(state.dict()) == sorted(exp_root_dict_keys) diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index dee55886a..19f385175 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -13,8 +13,11 @@ from reflex.vars import BaseVar, Var from tests.test_state import ( ChildState, ChildState2, + ChildState3, DateTimeState, GrandchildState, + GrandchildState2, + GrandchildState3, TestState, ) @@ -649,7 +652,7 @@ formatted_router = { "input, output", [ ( - TestState().dict(), # type: ignore + TestState(_reflex_internal_init=True).dict(), # type: ignore { TestState.get_full_name(): { "array": [1, 2, 3.14], @@ -674,11 +677,14 @@ formatted_router = { "value": "", }, ChildState2.get_full_name(): {"value": ""}, + ChildState3.get_full_name(): {"value": ""}, GrandchildState.get_full_name(): {"value2": ""}, + GrandchildState2.get_full_name(): {"cached": ""}, + GrandchildState3.get_full_name(): {"computed": ""}, }, ), ( - DateTimeState().dict(), + DateTimeState(_reflex_internal_init=True).dict(), # type: ignore { DateTimeState.get_full_name(): { "d": "1989-11-09",