From a2243190ff6818b6108ebf841e6dac1509e45569 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 31 Jan 2025 16:33:30 -0800 Subject: [PATCH] [ENG-4326] Async ComputedVar (#4711) * WiP * Save the var from get_var_name * flatten StateManagerRedis.get_state algorithm simplify fetching of states and avoid repeatedly fetching the same state * Get all the states in a single redis round-trip * update docstrings in StateManagerRedis * Move computed var dep tracking to separate module * Fix pre-commit issues * ComputedVar.add_dependency: explicitly dependency declaration Allow var dependencies to be added at runtime, for example, when defining a ComponentState that depends on vars that cannot be known statically. Fix more pyright issues. * Fix/ignore more pyright issues from recent merge * handle cleaning out _potentially_dirty_states on reload * ignore accessed attributes missing on state class these might be added dynamically later in which case we recompute the dependency tracking dicts... if not, they'll blow up anyway at runtime. * fix playwright tests, which insist on running an asyncio loop --------- Co-authored-by: Khaleel Al-Adhami --- reflex/app.py | 16 +- reflex/compiler/utils.py | 24 +- reflex/middleware/hydrate_middleware.py | 4 +- reflex/state.py | 570 ++++++++---------- reflex/utils/exec.py | 2 +- reflex/vars/base.py | 347 +++++++---- reflex/vars/dep_tracking.py | 344 +++++++++++ .../tests_playwright/test_table.py | 12 +- tests/units/test_app.py | 18 +- tests/units/test_state.py | 212 +++++-- tests/units/test_var.py | 28 +- 11 files changed, 1088 insertions(+), 489 deletions(-) create mode 100644 reflex/vars/dep_tracking.py diff --git a/reflex/app.py b/reflex/app.py index ad123a655..ce6808816 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -908,11 +908,17 @@ class App(MiddlewareMixin, LifespanMixin): if not var._cache: continue deps = var._deps(objclass=state) - for dep in deps: - if dep not in state.vars and dep not in state.backend_vars: - raise exceptions.VarDependencyError( - f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}" - ) + for state_name, dep_set in deps.items(): + state_cls = ( + state.get_root_state().get_class_substate(state_name) + if state_name != state.get_full_name() + else state + ) + for dep in dep_set: + if dep not in state_cls.vars and dep not in state_cls.backend_vars: + raise exceptions.VarDependencyError( + f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}" + ) for substate in state.class_subclasses: self._validate_var_dependencies(substate) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index d145e6c0b..9b388400b 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,12 +2,15 @@ from __future__ import annotations +import asyncio +import concurrent.futures import traceback from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union from urllib.parse import urlparse +from reflex.utils.exec import is_in_app_harness from reflex.utils.prerequisites import get_web_dir from reflex.vars.base import Var @@ -33,7 +36,7 @@ from reflex.components.base import ( ) from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.istate.storage import Cookie, LocalStorage, SessionStorage -from reflex.state import BaseState +from reflex.state import BaseState, _resolve_delta from reflex.style import Style from reflex.utils import console, format, imports, path_ops from reflex.utils.imports import ImportVar, ParsedImportDict @@ -177,7 +180,24 @@ def compile_state(state: Type[BaseState]) -> dict: initial_state = state(_reflex_internal_init=True).dict( initial=True, include_computed=False ) - return initial_state + try: + _ = asyncio.get_running_loop() + except RuntimeError: + pass + else: + if is_in_app_harness(): + # Playwright tests already have an event loop running, so we can't use asyncio.run. + with concurrent.futures.ThreadPoolExecutor() as pool: + resolved_initial_state = pool.submit( + asyncio.run, _resolve_delta(initial_state) + ).result() + console.warn( + f"Had to get initial state in a thread 🤮 {resolved_initial_state}", + ) + return resolved_initial_state + + # Normally the compile runs before any event loop starts, we asyncio.run is available for calling. + return asyncio.run(_resolve_delta(initial_state)) def _compile_client_storage_field( diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 2198b82c2..2dea54e17 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional from reflex import constants from reflex.event import Event, get_hydrate_event from reflex.middleware.middleware import Middleware -from reflex.state import BaseState, StateUpdate +from reflex.state import BaseState, StateUpdate, _resolve_delta if TYPE_CHECKING: from reflex.app import App @@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware): setattr(state, constants.CompileVars.IS_HYDRATED, False) # Get the initial state. - delta = state.dict() + delta = await _resolve_delta(state.dict()) # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/state.py b/reflex/state.py index 6c74d5e55..92aaa4710 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -15,7 +15,6 @@ import time import typing import uuid from abc import ABC, abstractmethod -from collections import defaultdict from hashlib import md5 from pathlib import Path from types import FunctionType, MethodType @@ -329,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): ) +async def _resolve_delta(delta: Delta) -> Delta: + """Await all coroutines in the delta. + + Args: + delta: The delta to process. + + Returns: + The same delta dict with all coroutines resolved to their return value. + """ + tasks = {} + for state_name, state_delta in delta.items(): + for var_name, value in state_delta.items(): + if asyncio.iscoroutine(value): + tasks[state_name, var_name] = asyncio.create_task(value) + for (state_name, var_name), task in tasks.items(): + delta[state_name][var_name] = await task + return delta + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -356,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A set of subclassses of this class. class_subclasses: ClassVar[Set[Type[BaseState]]] = set() - # Mapping of var name to set of computed variables that depend on it - _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} - - # Mapping of var name to set of substates that depend on it - _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} + # Mapping of var name to set of (state_full_name, var_name) that depend on it. + _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {} # Set of vars which always need to be recomputed _always_dirty_computed_vars: ClassVar[Set[str]] = set() @@ -368,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Set of substates which always need to be recomputed _always_dirty_substates: ClassVar[Set[str]] = set() + # Set of states which might need to be recomputed if vars in this state change. + _potentially_dirty_states: ClassVar[Set[str]] = set() + # The parent state. parent_state: Optional[BaseState] = None @@ -519,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Reset dirty substate tracking for this class. cls._always_dirty_substates = set() + cls._potentially_dirty_states = set() # Get the parent vars. parent_state = cls.get_parent_state() @@ -622,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): setattr(cls, name, handler) # Initialize per-class var dependency tracking. - cls._computed_var_dependencies = defaultdict(set) - cls._substate_var_dependencies = defaultdict(set) + cls._var_dependencies = {} cls._init_var_dependency_dicts() @staticmethod @@ -768,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Additional updates tracking dicts for vars and substates that always need to be recomputed. """ - inherited_vars = set(cls.inherited_vars).union( - set(cls.inherited_backend_vars), - ) for cvar_name, cvar in cls.computed_vars.items(): - # Add the dependencies. - for var in cvar._deps(objclass=cls): - cls._computed_var_dependencies[var].add(cvar_name) - if var in inherited_vars: - # track that this substate depends on its parent for this var - state_name = cls.get_name() - parent_state = cls.get_parent_state() - while parent_state is not None and var in { - **parent_state.vars, - **parent_state.backend_vars, + if not cvar._cache: + # Do not perform dep calculation when cache=False (these are always dirty). + continue + for state_name, dvar_set in cvar._deps(objclass=cls).items(): + state_cls = cls.get_root_state().get_class_substate(state_name) + for dvar in dvar_set: + defining_state_cls = state_cls + while dvar in { + *defining_state_cls.inherited_vars, + *defining_state_cls.inherited_backend_vars, }: - parent_state._substate_var_dependencies[var].add(state_name) - state_name, parent_state = ( - parent_state.get_name(), - parent_state.get_parent_state(), - ) + parent_state = defining_state_cls.get_parent_state() + if parent_state is not None: + defining_state_cls = parent_state + defining_state_cls._var_dependencies.setdefault(dvar, set()).add( + (cls.get_full_name(), cvar_name) + ) + defining_state_cls._potentially_dirty_states.add( + cls.get_full_name() + ) # ComputedVar with cache=False always need to be recomputed cls._always_dirty_computed_vars = { @@ -902,6 +921,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): raise ValueError(f"Only one parent state is allowed {parent_states}.") return parent_states[0] if len(parent_states) == 1 else None + @classmethod + @functools.lru_cache() + def get_root_state(cls) -> Type[BaseState]: + """Get the root state. + + Returns: + The root state. + """ + parent_state = cls.get_parent_state() + return cls if parent_state is None else parent_state.get_root_state() + @classmethod def get_substates(cls) -> set[Type[BaseState]]: """Get the substates of the state. @@ -1351,7 +1381,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): super().__setattr__(name, value) # Add the var to the dirty list. - if name in self.vars or name in self._computed_var_dependencies: + if name in self.base_vars: self.dirty_vars.add(name) self._mark_dirty() @@ -1422,64 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return self.substates[path[0]].get_substate(path[1:]) @classmethod - def _get_common_ancestor(cls, other: Type[BaseState]) -> str: - """Find the name of the nearest common ancestor shared by this and the other state. - - Args: - other: The other state. + def _get_potentially_dirty_states(cls) -> set[type[BaseState]]: + """Get substates which may have dirty vars due to dependencies. Returns: - Full name of the nearest common ancestor. + The set of potentially dirty substate classes. """ - common_ancestor_parts = [] - for part1, part2 in zip( - cls.get_full_name().split("."), - other.get_full_name().split("."), - strict=True, - ): - if part1 != part2: - break - common_ancestor_parts.append(part1) - return ".".join(common_ancestor_parts) - - @classmethod - def _determine_missing_parent_states( - cls, target_state_cls: Type[BaseState] - ) -> tuple[str, list[str]]: - """Determine the missing parent states between the target_state_cls and common ancestor of this state. - - Args: - target_state_cls: The class of the state to find missing parent states for. - - Returns: - The name of the common ancestor and the list of missing parent states. - """ - common_ancestor_name = cls._get_common_ancestor(target_state_cls) - common_ancestor_parts = common_ancestor_name.split(".") - target_state_parts = tuple(target_state_cls.get_full_name().split(".")) - relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :] - - # Determine which parent states to fetch from the common ancestor down to the target_state_cls. - fetch_parent_states = [common_ancestor_name] - for relative_parent_state_name in relative_target_state_parts: - fetch_parent_states.append( - ".".join((fetch_parent_states[-1], relative_parent_state_name)) - ) - - return common_ancestor_name, fetch_parent_states[1:-1] - - def _get_parent_states(self) -> list[tuple[str, BaseState]]: - """Get all parent state instances up to the root of the state tree. - - Returns: - A list of tuples containing the name and the instance of each parent state. - """ - parent_states_with_name = [] - parent_state = self - while parent_state.parent_state is not None: - parent_state = parent_state.parent_state - parent_states_with_name.append((parent_state.get_full_name(), parent_state)) - return parent_states_with_name + return { + cls.get_class_substate(substate_name) + for substate_name in cls._always_dirty_substates + }.union( + { + cls.get_root_state().get_class_substate(substate_name) + for substate_name in cls._potentially_dirty_states + } + ) def _get_root_state(self) -> BaseState: """Get the root state of the state tree. @@ -1492,55 +1479,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): parent_state = parent_state.parent_state return parent_state - async def _populate_parent_states(self, target_state_cls: Type[BaseState]): - """Populate substates in the tree between the target_state_cls and common ancestor of this state. + async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: + """Get a state instance from redis. Args: - target_state_cls: The class of the state to populate parent states for. + state_cls: The class of the state. Returns: - The parent state instance of target_state_cls. + The instance of state_cls associated with this state's client_token. Raises: RuntimeError: If redis is not used in this backend process. + StateMismatchError: If the state instance is not of the expected type. """ + # Then get the target state and all its substates. state_manager = get_state_manager() if not isinstance(state_manager, StateManagerRedis): raise RuntimeError( - f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. " + f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) + state_in_redis = await state_manager.get_state( + token=_substate_key(self.router.session.client_token, state_cls), + top_level=False, + for_state_instance=self, + ) - # Find the missing parent states up to the common ancestor. - ( - common_ancestor_name, - missing_parent_states, - ) = self._determine_missing_parent_states(target_state_cls) - - # Fetch all missing parent states and link them up to the common ancestor. - parent_states_tuple = self._get_parent_states() - root_state = parent_states_tuple[-1][1] - parent_states_by_name = dict(parent_states_tuple) - parent_state = parent_states_by_name[common_ancestor_name] - for parent_state_name in missing_parent_states: - try: - parent_state = root_state.get_substate(parent_state_name.split(".")) - # The requested state is already cached, do NOT fetch it again. - continue - except ValueError: - # The requested state is missing, fetch from redis. - pass - parent_state = await state_manager.get_state( - token=_substate_key( - self.router.session.client_token, parent_state_name - ), - top_level=False, - get_substates=False, - parent_state=parent_state, + if not isinstance(state_in_redis, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." ) - # Return the direct parent of target_state_cls for subsequent linking. - return parent_state + return state_in_redis def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from the cache. @@ -1562,44 +1532,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) return substate - async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: - """Get a state instance from redis. - - Args: - state_cls: The class of the state. - - Returns: - The instance of state_cls associated with this state's client_token. - - Raises: - RuntimeError: If redis is not used in this backend process. - StateMismatchError: If the state instance is not of the expected type. - """ - # Fetch all missing parent states from redis. - parent_state_of_state_cls = await self._populate_parent_states(state_cls) - - # Then get the target state and all its substates. - state_manager = get_state_manager() - if not isinstance(state_manager, StateManagerRedis): - raise RuntimeError( - f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " - "(All states should already be available -- this is likely a bug).", - ) - - state_in_redis = await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), - top_level=False, - get_substates=True, - parent_state=parent_state_of_state_cls, - ) - - if not isinstance(state_in_redis, state_cls): - raise StateMismatchError( - f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." - ) - - return state_in_redis - async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE: """Get an instance of the state associated with this token. @@ -1738,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) - def _as_state_update( + async def _as_state_update( self, handler: EventHandler, events: EventSpec | list[EventSpec] | None, @@ -1766,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): try: # Get the delta after processing the event. - delta = state.get_delta() + delta = await _resolve_delta(state.get_delta()) state._clean() return StateUpdate( @@ -1866,24 +1798,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Handle async generators. if inspect.isasyncgen(events): async for event in events: - yield state._as_state_update(handler, event, final=False) - yield state._as_state_update(handler, events=None, final=True) + yield await state._as_state_update(handler, event, final=False) + yield await state._as_state_update(handler, events=None, final=True) # Handle regular generators. elif inspect.isgenerator(events): try: while True: - yield state._as_state_update(handler, next(events), final=False) + yield await state._as_state_update( + handler, next(events), final=False + ) except StopIteration as si: # the "return" value of the generator is not available # in the loop, we must catch StopIteration to access it if si.value is not None: - yield state._as_state_update(handler, si.value, final=False) - yield state._as_state_update(handler, events=None, final=True) + yield await state._as_state_update( + handler, si.value, final=False + ) + yield await state._as_state_update(handler, events=None, final=True) # Handle regular event chains. else: - yield state._as_state_update(handler, events, final=True) + yield await state._as_state_update(handler, events, final=True) # If an error occurs, throw a window alert. except Exception as ex: @@ -1893,7 +1829,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): prerequisites.get_and_validate_app().app.backend_exception_handler(ex) ) - yield state._as_state_update( + yield await state._as_state_update( handler, event_specs, final=True, @@ -1901,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" + # Append expired computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._expired_computed_vars()) + # Append always dirty computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._always_dirty_computed_vars) + dirty_vars = self.dirty_vars while dirty_vars: calc_vars, dirty_vars = dirty_vars, set() - for cvar in self._dirty_computed_vars(from_vars=calc_vars): - self.dirty_vars.add(cvar) + for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars): + if state_name == self.get_full_name(): + defining_state = self + else: + defining_state = self._get_root_state().get_substate( + tuple(state_name.split(".")) + ) + defining_state.dirty_vars.add(cvar) dirty_vars.add(cvar) - actual_var = self.computed_vars.get(cvar) + actual_var = defining_state.computed_vars.get(cvar) if actual_var is not None: - actual_var.mark_dirty(instance=self) + actual_var.mark_dirty(instance=defining_state) + if defining_state is not self: + defining_state._mark_dirty() def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. @@ -1925,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def _dirty_computed_vars( self, from_vars: set[str] | None = None, include_backend: bool = True - ) -> set[str]: + ) -> set[tuple[str, str]]: """Determine ComputedVars that need to be recalculated based on the given vars. Args: @@ -1936,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Set of computed vars to include in the delta. """ return { - cvar + (state_name, cvar) for dirty_var in from_vars or self.dirty_vars - for cvar in self._computed_var_dependencies[dirty_var] + for state_name, cvar in self._var_dependencies.get(dirty_var, set()) if include_backend or not self.computed_vars[cvar]._backend } - @classmethod - def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: - """Determine substates which could be affected by dirty vars in this state. - - Returns: - Set of State classes that may need to be fetched to recalc computed vars. - """ - # _always_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) - for substate_name in cls._always_dirty_substates - } - for dependent_substates in cls._substate_var_dependencies.values(): - fetch_substates.update( - { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) - for substate_name in dependent_substates - } - ) - return fetch_substates - def get_delta(self) -> Delta: """Get the delta for the state. @@ -1971,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ delta = {} - # Apply dirty variables down into substates - self.dirty_vars.update(self._always_dirty_computed_vars) - self._mark_dirty() - + self._mark_dirty_computed_vars() frontend_computed_vars: set[str] = { name for name, cv in self.computed_vars.items() if not cv._backend } # Return the dirty vars for this instance, any cached/dependent computed vars, # and always dirty computed vars (cache=False) - delta_vars = ( - self.dirty_vars.intersection(self.base_vars) - .union(self.dirty_vars.intersection(frontend_computed_vars)) - .union(self._dirty_computed_vars(include_backend=False)) - .union(self._always_dirty_computed_vars) + delta_vars = self.dirty_vars.intersection(self.base_vars).union( + self.dirty_vars.intersection(frontend_computed_vars) ) subdelta: Dict[str, Any] = { @@ -2015,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state._mark_dirty() - # Append expired computed vars to dirty_vars to trigger recalculation - self.dirty_vars.update(self._expired_computed_vars()) - # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function self._mark_dirty_computed_vars() - self._mark_dirty_substates() - - def _mark_dirty_substates(self): - """Propagate dirty var / computed var status into substates.""" - substates = self.substates - for var in self.dirty_vars: - for substate_name in self._substate_var_dependencies[var]: - self.dirty_substates.add(substate_name) - substate = substates[substate_name] - substate.dirty_vars.add(var) - substate._mark_dirty() def _update_was_touched(self): """Update the _was_touched flag based on dirty_vars.""" @@ -2103,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): The object as a dictionary. """ if include_computed: - # Apply dirty variables down into substates to allow never-cached ComputedVar to - # trigger recalculation of dependent vars - self.dirty_vars.update(self._always_dirty_computed_vars) - self._mark_dirty() - + self._mark_dirty_computed_vars() base_vars = { prop_name: self.get_value(prop_name) for prop_name in self.base_vars } @@ -2824,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy): await self.__wrapped__.get_state(state_cls), parent_state_proxy=self ) - def _as_state_update(self, *args, **kwargs) -> StateUpdate: + async def _as_state_update(self, *args, **kwargs) -> StateUpdate: """Temporarily allow mutability to access parent_state. Args: @@ -2837,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy): original_mutable = self._self_mutable self._self_mutable = True try: - return self.__wrapped__._as_state_update(*args, **kwargs) + return await self.__wrapped__._as_state_update(*args, **kwargs) finally: self._self_mutable = original_mutable @@ -3313,103 +3217,106 @@ class StateManagerRedis(StateManager): b"evicted", } - async def _get_parent_state( - self, token: str, state: BaseState | None = None - ) -> BaseState | None: - """Get the parent state for the state requested in the token. + def _get_required_state_classes( + self, + target_state_cls: Type[BaseState], + subclasses: bool = False, + required_state_classes: set[Type[BaseState]] | None = None, + ) -> set[Type[BaseState]]: + """Recursively determine which states are required to fetch the target state. + + This will always include potentially dirty substates that depend on vars + in the target_state_cls. Args: - token: The token to get the state for (_substate_key). - state: The state instance to get parent state for. + target_state_cls: The target state class being fetched. + subclasses: Whether to include subclasses of the target state. + required_state_classes: Recursive argument tracking state classes that have already been seen. Returns: - The parent state for the state requested by the token or None if there is no such parent. + The set of state classes required to fetch the target state. """ - parent_state = None - client_token, state_path = _split_substate_key(token) - parent_state_name = state_path.rpartition(".")[0] - if parent_state_name: - cached_substates = None - if state is not None: - cached_substates = [state] - # Retrieve the parent state to populate event handlers onto this substate. - parent_state = await self.get_state( - token=_substate_key(client_token, parent_state_name), - top_level=False, - get_substates=False, - cached_substates=cached_substates, + if required_state_classes is None: + required_state_classes = set() + # Get the substates if requested. + if subclasses: + for substate in target_state_cls.get_substates(): + self._get_required_state_classes( + substate, + subclasses=True, + required_state_classes=required_state_classes, + ) + if target_state_cls in required_state_classes: + return required_state_classes + required_state_classes.add(target_state_cls) + + # Get dependent substates. + for pd_substates in target_state_cls._get_potentially_dirty_states(): + self._get_required_state_classes( + pd_substates, + subclasses=False, + required_state_classes=required_state_classes, ) - return parent_state - async def _populate_substates( + # Get the parent state if it exists. + if parent_state := target_state_cls.get_parent_state(): + self._get_required_state_classes( + parent_state, + subclasses=False, + required_state_classes=required_state_classes, + ) + return required_state_classes + + def _get_populated_states( self, - token: str, - state: BaseState, - all_substates: bool = False, - ): - """Fetch and link substates for the given state instance. - - There is no return value; the side-effect is that `state` will have `substates` populated, - and each substate will have its `parent_state` set to `state`. + target_state: BaseState, + populated_states: dict[str, BaseState] | None = None, + ) -> dict[str, BaseState]: + """Recursively determine which states from target_state are already fetched. Args: - token: The token to get the state for. - state: The state instance to populate substates for. - all_substates: Whether to fetch all substates or just required substates. + target_state: The state to check for populated states. + populated_states: Recursive argument tracking states seen in previous calls. + + Returns: + A dictionary of state full name to state instance. """ - client_token, _ = _split_substate_key(token) - - if all_substates: - # All substates are requested. - fetch_substates = state.get_substates() - else: - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._potentially_dirty_substates() - - tasks = {} - # Retrieve the necessary substates from redis. - for substate_cls in fetch_substates: - if substate_cls.get_name() in state.substates: - continue - substate_name = substate_cls.get_name() - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, - ) + if populated_states is None: + populated_states = {} + if target_state.get_full_name() in populated_states: + return populated_states + populated_states[target_state.get_full_name()] = target_state + for substate in target_state.substates.values(): + self._get_populated_states(substate, populated_states=populated_states) + if target_state.parent_state is not None: + self._get_populated_states( + target_state.parent_state, populated_states=populated_states ) - - for substate_name, substate_task in tasks.items(): - state.substates[substate_name] = await substate_task + return populated_states @override async def get_state( self, token: str, top_level: bool = True, - get_substates: bool = True, - parent_state: BaseState | None = None, - cached_substates: list[BaseState] | None = None, + for_state_instance: BaseState | None = None, ) -> BaseState: """Get the state for a token. Args: token: The token to get the state for. top_level: If true, return an instance of the top-level state (self.state). - get_substates: If true, also retrieve substates. - parent_state: If provided, use this parent_state instead of getting it from redis. - cached_substates: If provided, attach these substates to the state. + for_state_instance: If provided, attach the requested states to this existing state tree. Returns: The state for the token. Raises: - RuntimeError: when the state_cls is not specified in the token + RuntimeError: when the state_cls is not specified in the token, or when the parent state for a + requested state was not fetched. """ # Split the actual token from the fully qualified substate name. - _, state_path = _split_substate_key(token) + token, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. state_cls = self.state.get_class_substate(state_path) @@ -3418,43 +3325,59 @@ class StateManagerRedis(StateManager): f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" ) - # The deserialized or newly created (sub)state instance. - state = None + # Determine which states we already have. + flat_state_tree: dict[str, BaseState] = ( + self._get_populated_states(for_state_instance) if for_state_instance else {} + ) - # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + # Determine which states from the tree need to be fetched. + required_state_classes = sorted( + self._get_required_state_classes(state_cls, subclasses=True) + - {type(s) for s in flat_state_tree.values()}, + key=lambda x: x.get_full_name(), + ) - if redis_state is not None: - # Deserialize the substate. - with contextlib.suppress(StateSchemaMismatchError): - state = BaseState._deserialize(data=redis_state) - if state is None: - # Key didn't exist or schema mismatch so create a new instance for this token. - state = state_cls( - init_substates=False, - _reflex_internal_init=True, - ) - # Populate parent state if missing and requested. - if parent_state is None: - parent_state = await self._get_parent_state(token, state) - # Set up Bidirectional linkage between this state and its parent. - if parent_state is not None: - parent_state.substates[state.get_name()] = state - state.parent_state = parent_state - # Avoid fetching substates multiple times. - if cached_substates: - for substate in cached_substates: - state.substates[substate.get_name()] = substate - if substate.parent_state is None: - substate.parent_state = state - # Populate substates if requested. - await self._populate_substates(token, state, all_substates=get_substates) + redis_pipeline = self.redis.pipeline() + for state_cls in required_state_classes: + redis_pipeline.get(_substate_key(token, state_cls)) + + for state_cls, redis_state in zip( + required_state_classes, + await redis_pipeline.execute(), + strict=False, + ): + state = None + + if redis_state is not None: + # Deserialize the substate. + with contextlib.suppress(StateSchemaMismatchError): + state = BaseState._deserialize(data=redis_state) + if state is None: + # Key didn't exist or schema mismatch so create a new instance for this token. + state = state_cls( + init_substates=False, + _reflex_internal_init=True, + ) + flat_state_tree[state.get_full_name()] = state + if state.get_parent_state() is not None: + parent_state_name, _dot, state_name = state.get_full_name().rpartition( + "." + ) + parent_state = flat_state_tree.get(parent_state_name) + if parent_state is None: + raise RuntimeError( + f"Parent state for {state.get_full_name()} was not found " + "in the state tree, but should have already been fetched. " + "This is a bug", + ) + parent_state.substates[state_name] = state + state.parent_state = parent_state # To retain compatibility with previous implementation, by default, we return - # the top-level state by chasing `parent_state` pointers up the tree. + # the top-level state which should always be fetched or already cached. if top_level: - return state._get_root_state() - return state + return flat_state_tree[self.state.get_full_name()] + return flat_state_tree[state_cls.get_full_name()] @override async def set_state( @@ -4154,12 +4077,19 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ + # Clean out all potentially dirty states of reloaded modules. + for pd_state in tuple(state._potentially_dirty_states): + with contextlib.suppress(ValueError): + if ( + state.get_root_state().get_class_substate(pd_state).__module__ == module + and module is not None + ): + state._potentially_dirty_states.remove(pd_state) for subclass in tuple(state.class_subclasses): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._computed_var_dependencies = defaultdict(set) - state._substate_var_dependencies = defaultdict(set) + state._var_dependencies = {} state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 479ff816a..67df7ea91 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -488,7 +488,7 @@ def output_system_info(): dependencies.append(fnm_info) if system == "Linux": - import distro + import distro # pyright: ignore[reportMissingImports] os_version = distro.name(pretty=True) else: diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 8a76f250d..ec65c3711 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib import dataclasses import datetime -import dis import functools import inspect import json @@ -20,6 +19,7 @@ from typing import ( Any, Callable, ClassVar, + Coroutine, Dict, FrozenSet, Generic, @@ -51,7 +51,6 @@ from reflex.utils.exceptions import ( VarAttributeError, VarDependencyError, VarTypeError, - VarValueError, ) from reflex.utils.format import format_state_name from reflex.utils.imports import ( @@ -1983,7 +1982,7 @@ class ComputedVar(Var[RETURN_TYPE]): _initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset()) # Explicit var dependencies to track - _static_deps: set[str] = dataclasses.field(default_factory=set) + _static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict) # Whether var dependencies should be auto-determined _auto_deps: bool = dataclasses.field(default=True) @@ -2053,21 +2052,34 @@ class ComputedVar(Var[RETURN_TYPE]): object.__setattr__(self, "_update_interval", interval) - if deps is None: - deps = [] - else: + _static_deps = {} + if isinstance(deps, dict): + # Assume a dict is coming from _replace, so no special processing. + _static_deps = deps + elif deps is not None: for dep in deps: if isinstance(dep, Var): - continue - if isinstance(dep, str) and dep != "": - continue - raise TypeError( - "ComputedVar dependencies must be Var instances or var names (non-empty strings)." - ) + state_name = ( + all_var_data.state + if (all_var_data := dep._get_all_var_data()) + and all_var_data.state + else None + ) + if all_var_data is not None: + var_name = all_var_data.field_name + else: + var_name = dep._js_expr + _static_deps.setdefault(state_name, set()).add(var_name) + elif isinstance(dep, str) and dep != "": + _static_deps.setdefault(None, set()).add(dep) + else: + raise TypeError( + "ComputedVar dependencies must be Var instances or var names (non-empty strings)." + ) object.__setattr__( self, "_static_deps", - {dep._js_expr if isinstance(dep, Var) else dep for dep in deps}, + _static_deps, ) object.__setattr__(self, "_auto_deps", auto_deps) @@ -2149,6 +2161,13 @@ class ComputedVar(Var[RETURN_TYPE]): return True return datetime.datetime.now() - last_updated > self._update_interval + @overload + def __get__( + self: ComputedVar[bool], + instance: None, + owner: Type, + ) -> BooleanVar: ... + @overload def __get__( self: ComputedVar[int] | ComputedVar[float], @@ -2233,125 +2252,67 @@ class ComputedVar(Var[RETURN_TYPE]): setattr(instance, self._last_updated_attr, datetime.datetime.now()) value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + + return value + + def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None: if not _isinstance(value, self._var_type): console.error( f"Computed var '{type(instance).__name__}.{self._js_expr}' must return" f" type '{self._var_type}', got '{type(value)}'." ) - return value - def _deps( self, - objclass: Type, + objclass: Type[BaseState], obj: FunctionType | CodeType | None = None, - self_name: Optional[str] = None, - ) -> set[str]: + ) -> dict[str, set[str]]: """Determine var dependencies of this ComputedVar. - Save references to attributes accessed on "self". Recursively called - when the function makes a method call on "self" or define comprehensions - or nested functions that may reference "self". + Save references to attributes accessed on "self" or other fetched states. + + Recursively called when the function makes a method call on "self" or + define comprehensions or nested functions that may reference "self". Args: objclass: the class obj this ComputedVar is attached to. obj: the object to disassemble (defaults to the fget function). - self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions. Returns: - A set of variable names accessed by the given obj. - - Raises: - VarValueError: if the function references the get_state, parent_state, or substates attributes - (cannot track deps in a related state, only implicitly via parent state). + A dictionary mapping state names to the set of variable names + accessed by the given obj. """ + from .dep_tracking import DependencyTracker + + d = {} + if self._static_deps: + d.update(self._static_deps) + # None is a placeholder for the current state class. + if None in d: + d[objclass.get_full_name()] = d.pop(None) + if not self._auto_deps: - return self._static_deps - d = self._static_deps.copy() + return d + if obj is None: fget = self._fget if fget is not None: obj = cast(FunctionType, fget) else: - return set() - with contextlib.suppress(AttributeError): - # unbox functools.partial - obj = cast(FunctionType, obj.func) # pyright: ignore [reportAttributeAccessIssue] - with contextlib.suppress(AttributeError): - # unbox EventHandler - obj = cast(FunctionType, obj.fn) # pyright: ignore [reportAttributeAccessIssue] + return d - if self_name is None and isinstance(obj, FunctionType): - try: - # the first argument to the function is the name of "self" arg - self_name = obj.__code__.co_varnames[0] - except (AttributeError, IndexError): - self_name = None - if self_name is None: - # cannot reference attributes on self if method takes no args - return set() - - invalid_names = ["get_state", "parent_state", "substates", "get_substate"] - self_is_top_of_stack = False - for instruction in dis.get_instructions(obj): - if ( - instruction.opname in ("LOAD_FAST", "LOAD_DEREF") - and instruction.argval == self_name - ): - # bytecode loaded the class instance to the top of stack, next load instruction - # is referencing an attribute on self - self_is_top_of_stack = True - continue - if self_is_top_of_stack and instruction.opname in ( - "LOAD_ATTR", - "LOAD_METHOD", - ): - try: - ref_obj = getattr(objclass, instruction.argval) - except Exception: - ref_obj = None - if instruction.argval in invalid_names: - raise VarValueError( - f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." - ) - if callable(ref_obj): - # recurse into callable attributes - d.update( - self._deps( - objclass=objclass, - obj=ref_obj, # pyright: ignore [reportArgumentType] - ) - ) - # recurse into property fget functions - elif isinstance(ref_obj, property) and not isinstance( - ref_obj, ComputedVar - ): - d.update( - self._deps( - objclass=objclass, - obj=ref_obj.fget, # pyright: ignore [reportArgumentType] - ) - ) - elif ( - instruction.argval in objclass.backend_vars - or instruction.argval in objclass.vars - ): - # var access - d.add(instruction.argval) - elif instruction.opname == "LOAD_CONST" and isinstance( - instruction.argval, CodeType - ): - # recurse into nested functions / comprehensions, which can reference - # instance attributes from the outer scope - d.update( - self._deps( - objclass=objclass, - obj=instruction.argval, - self_name=self_name, - ) - ) - self_is_top_of_stack = False - return d + try: + return DependencyTracker( + func=obj, state_cls=objclass, dependencies=d + ).dependencies + except Exception as e: + console.warn( + "Failed to automatically determine dependencies for computed var " + f"{objclass.__name__}.{self._js_expr}: {e}. " + "Provide static_deps and set auto_deps=False to suppress this warning." + ) + return d def mark_dirty(self, instance: BaseState) -> None: """Mark this ComputedVar as dirty. @@ -2362,6 +2323,37 @@ class ComputedVar(Var[RETURN_TYPE]): with contextlib.suppress(AttributeError): delattr(instance, self._cache_attr) + def add_dependency(self, objclass: Type[BaseState], dep: Var): + """Explicitly add a dependency to the ComputedVar. + + After adding the dependency, when the `dep` changes, this computed var + will be marked dirty. + + Args: + objclass: The class obj this ComputedVar is attached to. + dep: The dependency to add. + + Raises: + VarDependencyError: If the dependency is not a Var instance with a + state and field name + """ + if all_var_data := dep._get_all_var_data(): + state_name = all_var_data.state + if state_name: + var_name = all_var_data.field_name + if var_name: + self._static_deps.setdefault(state_name, set()).add(var_name) + objclass.get_root_state().get_class_substate( + state_name + )._var_dependencies.setdefault(var_name, set()).add( + (objclass.get_full_name(), self._js_expr) + ) + return + raise VarDependencyError( + "ComputedVar dependencies must be Var instances with a state and " + f"field name, got {dep!r}." + ) + def _determine_var_type(self) -> Type: """Get the type of the var. @@ -2398,6 +2390,126 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]): pass +async def _default_async_computed_var(_self: BaseState) -> Any: + return None + + +@dataclasses.dataclass( + eq=False, + frozen=True, + init=False, + slots=True, +) +class AsyncComputedVar(ComputedVar[RETURN_TYPE]): + """A computed var that wraps a coroutinefunction.""" + + _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = ( + dataclasses.field(default=_default_async_computed_var) + ) + + @overload + def __get__( + self: AsyncComputedVar[bool], + instance: None, + owner: Type, + ) -> BooleanVar: ... + + @overload + def __get__( + self: AsyncComputedVar[int] | ComputedVar[float], + instance: None, + owner: Type, + ) -> NumberVar: ... + + @overload + def __get__( + self: AsyncComputedVar[str], + instance: None, + owner: Type, + ) -> StringVar: ... + + @overload + def __get__( + self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]], + instance: None, + owner: Type, + ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... + + @overload + def __get__( + self: AsyncComputedVar[list[LIST_INSIDE]], + instance: None, + owner: Type, + ) -> ArrayVar[list[LIST_INSIDE]]: ... + + @overload + def __get__( + self: AsyncComputedVar[tuple[LIST_INSIDE, ...]], + instance: None, + owner: Type, + ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... + + @overload + def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ... + + @overload + def __get__( + self, instance: BaseState, owner: Type + ) -> Coroutine[None, None, RETURN_TYPE]: ... + + def __get__( + self, instance: BaseState | None, owner + ) -> Var | Coroutine[None, None, RETURN_TYPE]: + """Get the ComputedVar value. + + If the value is already cached on the instance, return the cached value. + + Args: + instance: the instance of the class accessing this computed var. + owner: the class that this descriptor is attached to. + + Returns: + The value of the var for the given instance. + """ + if instance is None: + return super(AsyncComputedVar, self).__get__(instance, owner) + + if not self._cache: + + async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE: + value = await self.fget(instance) + self._check_deprecated_return_type(instance, value) + return value + + return _awaitable_result() + else: + # handle caching + async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE: + if not hasattr(instance, self._cache_attr) or self.needs_update( + instance + ): + # Set cache attr on state instance. + setattr(instance, self._cache_attr, await self.fget(instance)) + # Ensure the computed var gets serialized to redis. + instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + return value + + return _awaitable_result() + + @property + def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]: + """Get the getter function. + + Returns: + The getter function. + """ + return self._fget + + if TYPE_CHECKING: BASE_STATE = TypeVar("BASE_STATE", bound=BaseState) @@ -2464,10 +2576,27 @@ def computed_var( raise VarDependencyError("Cannot track dependencies without caching.") if fget is not None: - return ComputedVar(fget, cache=cache) + if inspect.iscoroutinefunction(fget): + computed_var_cls = AsyncComputedVar + else: + computed_var_cls = ComputedVar + return computed_var_cls( + fget, + initial_value=initial_value, + cache=cache, + deps=deps, + auto_deps=auto_deps, + interval=interval, + backend=backend, + **kwargs, + ) def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar: - return ComputedVar( + if inspect.iscoroutinefunction(fget): + computed_var_cls = AsyncComputedVar + else: + computed_var_cls = ComputedVar + return computed_var_cls( fget, initial_value=initial_value, cache=cache, diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py new file mode 100644 index 000000000..0b2367799 --- /dev/null +++ b/reflex/vars/dep_tracking.py @@ -0,0 +1,344 @@ +"""Collection of base classes.""" + +from __future__ import annotations + +import contextlib +import dataclasses +import dis +import enum +import inspect +from types import CellType, CodeType, FunctionType +from typing import TYPE_CHECKING, Any, ClassVar, Type, cast + +from reflex.utils.exceptions import VarValueError + +if TYPE_CHECKING: + from reflex.state import BaseState + + from .base import Var + + +CellEmpty = object() + + +def get_cell_value(cell: CellType) -> Any: + """Get the value of a cell object. + + Args: + cell: The cell object to get the value from. (func.__closure__ objects) + + Returns: + The value from the cell or CellEmpty if a ValueError is raised. + """ + try: + return cell.cell_contents + except ValueError: + return CellEmpty + + +class ScanStatus(enum.Enum): + """State of the dis instruction scanning loop.""" + + SCANNING = enum.auto() + GETTING_ATTR = enum.auto() + GETTING_STATE = enum.auto() + GETTING_VAR = enum.auto() + + +@dataclasses.dataclass +class DependencyTracker: + """State machine for identifying state attributes that are accessed by a function.""" + + func: FunctionType | CodeType = dataclasses.field() + state_cls: Type[BaseState] = dataclasses.field() + + dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict) + + scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING) + top_of_stack: str | None = dataclasses.field(default=None) + + tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict) + + _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None) + _getting_var_instructions: list[dis.Instruction] = dataclasses.field( + default_factory=list + ) + + INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"] + + def __post_init__(self): + """After initializing, populate the dependencies dict.""" + with contextlib.suppress(AttributeError): + # unbox functools.partial + self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportAttributeAccessIssue] + with contextlib.suppress(AttributeError): + # unbox EventHandler + self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportAttributeAccessIssue] + + if isinstance(self.func, FunctionType): + with contextlib.suppress(AttributeError, IndexError): + # the first argument to the function is the name of "self" arg + self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls + + self._populate_dependencies() + + def _merge_deps(self, tracker: DependencyTracker) -> None: + """Merge dependencies from another DependencyTracker. + + Args: + tracker: The DependencyTracker to merge dependencies from. + """ + for state_name, dep_name in tracker.dependencies.items(): + self.dependencies.setdefault(state_name, set()).update(dep_name) + + def load_attr_or_method(self, instruction: dis.Instruction) -> None: + """Handle loading an attribute or method from the object on top of the stack. + + This method directly tracks attributes and recursively merges + dependencies from analyzing the dependencies of any methods called. + + Args: + instruction: The dis instruction to process. + + Raises: + VarValueError: if the attribute is an disallowed name. + """ + from .base import ComputedVar + + if instruction.argval in self.INVALID_NAMES: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." + ) + if instruction.argval == "get_state": + # Special case: arbitrary state access requested. + self.scan_status = ScanStatus.GETTING_STATE + return + if instruction.argval == "get_var_value": + # Special case: arbitrary var access requested. + self.scan_status = ScanStatus.GETTING_VAR + return + + # Reset status back to SCANNING after attribute is accessed. + self.scan_status = ScanStatus.SCANNING + if not self.top_of_stack: + return + target_state = self.tracked_locals[self.top_of_stack] + try: + ref_obj = getattr(target_state, instruction.argval) + except AttributeError: + # Not found on this state class, maybe it is a dynamic attribute that will be picked up later. + ref_obj = None + + if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar): + # recurse into property fget functions + ref_obj = ref_obj.fget + if callable(ref_obj): + # recurse into callable attributes + self._merge_deps( + type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state) + ) + elif ( + instruction.argval in target_state.backend_vars + or instruction.argval in target_state.vars + ): + # var access + self.dependencies.setdefault(target_state.get_full_name(), set()).add( + instruction.argval + ) + + def _get_globals(self) -> dict[str, Any]: + """Get the globals of the function. + + Returns: + The var names and values in the globals of the function. + """ + if isinstance(self.func, CodeType): + return {} + return self.func.__globals__ # pyright: ignore[reportAttributeAccessIssue] + + def _get_closure(self) -> dict[str, Any]: + """Get the closure of the function, with unbound values omitted. + + Returns: + The var names and values in the closure of the function. + """ + if isinstance(self.func, CodeType): + return {} + return { + var_name: get_cell_value(cell) + for var_name, cell in zip( + self.func.__code__.co_freevars, # pyright: ignore[reportAttributeAccessIssue] + self.func.__closure__ or (), + strict=False, + ) + if get_cell_value(cell) is not CellEmpty + } + + def handle_getting_state(self, instruction: dis.Instruction) -> None: + """Handle bytecode analysis when `get_state` was called in the function. + + If the wrapped function is getting an arbitrary state and saving it to a + local variable, this method associates the local variable name with the + state class in self.tracked_locals. + + When an attribute/method is accessed on a tracked local, it will be + associated with this state. + + Args: + instruction: The dis instruction to process. + + Raises: + VarValueError: if the state class cannot be determined from the instruction. + """ + from reflex.state import BaseState + + if instruction.opname == "LOAD_FAST": + raise VarValueError( + f"Dependency detection cannot identify get_state class from local var {instruction.argval}." + ) + if isinstance(self.func, CodeType): + raise VarValueError( + "Dependency detection cannot identify get_state class from a code object." + ) + if instruction.opname == "LOAD_GLOBAL": + # Special case: referencing state class from global scope. + try: + self._getting_state_class = self._get_globals()[instruction.argval] + except (ValueError, KeyError) as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." + ) from ve + elif instruction.opname == "LOAD_DEREF": + # Special case: referencing state class from closure. + try: + self._getting_state_class = self._get_closure()[instruction.argval] + except (ValueError, KeyError) as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?" + ) from ve + elif instruction.opname == "STORE_FAST": + # Storing the result of get_state in a local variable. + if not isinstance(self._getting_state_class, type) or not issubclass( + self._getting_state_class, BaseState + ): + raise VarValueError( + f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." + ) + self.tracked_locals[instruction.argval] = self._getting_state_class + self.scan_status = ScanStatus.SCANNING + self._getting_state_class = None + + def _eval_var(self) -> Var: + """Evaluate instructions from the wrapped function to get the Var object. + + Returns: + The Var object. + + Raises: + VarValueError: if the source code for the var cannot be determined. + """ + # Get the original source code and eval it to get the Var. + module = inspect.getmodule(self.func) + positions0 = self._getting_var_instructions[0].positions + positions1 = self._getting_var_instructions[-1].positions + if module is None or positions0 is None or positions1 is None: + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + start_line = positions0.lineno + start_column = positions0.col_offset + end_line = positions1.end_lineno + end_column = positions1.end_col_offset + if ( + start_line is None + or start_column is None + or end_line is None + or end_column is None + ): + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line] + # Create a python source string snippet. + if len(source) > 1: + snipped_source = "".join( + [ + *source[0][start_column:], + *(source[1:-2] if len(source) > 2 else []), + *source[-1][: end_column - 1], + ] + ) + else: + snipped_source = source[0][start_column : end_column - 1] + # Evaluate the string in the context of the function's globals and closure. + return eval(f"({snipped_source})", self._get_globals(), self._get_closure()) + + def handle_getting_var(self, instruction: dis.Instruction) -> None: + """Handle bytecode analysis when `get_var_value` was called in the function. + + This only really works if the expression passed to `get_var_value` is + evaluable in the function's global scope or closure, so getting the var + value from a var saved in a local variable or in the class instance is + not possible. + + Args: + instruction: The dis instruction to process. + + Raises: + VarValueError: if the source code for the var cannot be determined. + """ + if instruction.opname == "CALL" and self._getting_var_instructions: + if self._getting_var_instructions: + the_var = self._eval_var() + the_var_data = the_var._get_all_var_data() + if the_var_data is None: + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + self.dependencies.setdefault(the_var_data.state, set()).add( + the_var_data.field_name + ) + self._getting_var_instructions.clear() + self.scan_status = ScanStatus.SCANNING + else: + self._getting_var_instructions.append(instruction) + + def _populate_dependencies(self) -> None: + """Update self.dependencies based on the disassembly of self.func. + + Save references to attributes accessed on "self" or other fetched states. + + Recursively called when the function makes a method call on "self" or + define comprehensions or nested functions that may reference "self". + """ + for instruction in dis.get_instructions(self.func): + if self.scan_status == ScanStatus.GETTING_STATE: + self.handle_getting_state(instruction) + elif self.scan_status == ScanStatus.GETTING_VAR: + self.handle_getting_var(instruction) + elif ( + instruction.opname in ("LOAD_FAST", "LOAD_DEREF") + and instruction.argval in self.tracked_locals + ): + # bytecode loaded the class instance to the top of stack, next load instruction + # is referencing an attribute on self + self.top_of_stack = instruction.argval + self.scan_status = ScanStatus.GETTING_ATTR + elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in ( + "LOAD_ATTR", + "LOAD_METHOD", + ): + self.load_attr_or_method(instruction) + self.top_of_stack = None + elif instruction.opname == "LOAD_CONST" and isinstance( + instruction.argval, CodeType + ): + # recurse into nested functions / comprehensions, which can reference + # instance attributes from the outer scope + self._merge_deps( + type(self)( + func=instruction.argval, + state_cls=self.state_cls, + tracked_locals=self.tracked_locals, + ) + ) diff --git a/tests/integration/tests_playwright/test_table.py b/tests/integration/tests_playwright/test_table.py index bd399a840..a88c4a621 100644 --- a/tests/integration/tests_playwright/test_table.py +++ b/tests/integration/tests_playwright/test_table.py @@ -3,7 +3,7 @@ from typing import Generator import pytest -from playwright.sync_api import Page +from playwright.sync_api import Page, expect from reflex.testing import AppHarness @@ -87,12 +87,14 @@ def test_table(page: Page, table_app: AppHarness): table = page.get_by_role("table") # Check column headers - headers = table.get_by_role("columnheader").all_inner_texts() - assert headers == expected_col_headers + headers = table.get_by_role("columnheader") + for header, exp_value in zip(headers.all(), expected_col_headers, strict=True): + expect(header).to_have_text(exp_value) # Check rows headers - rows = table.get_by_role("rowheader").all_inner_texts() - assert rows == expected_row_headers + rows = table.get_by_role("rowheader") + for row, expected_row in zip(rows.all(), expected_row_headers, strict=True): + expect(row).to_have_text(expected_row) # Check cells rows = table.get_by_role("cell").all_inner_texts() diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 4a6c16d6e..bf1a8a313 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool): assert app._pages.keys() == {"test/[dynamic]"} assert "dynamic" in app._state.computed_vars assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { - constants.ROUTER + EmptyState.get_full_name(): {constants.ROUTER}, } - assert constants.ROUTER in app._state()._computed_var_dependencies + assert constants.ROUTER in app._state()._var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): @@ -995,9 +995,9 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert arg_name in app._state.vars assert arg_name in app._state.computed_vars assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == { - constants.ROUTER + DynamicState.get_full_name(): {constants.ROUTER}, } - assert constants.ROUTER in app._state()._computed_var_dependencies + assert constants.ROUTER in app._state()._var_dependencies substate_token = _substate_key(token, DynamicState) sid = "mock_sid" @@ -1555,6 +1555,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]): def bar(self) -> str: return "bar" + class Child1(ValidDepState): + @computed_var(deps=["base", ValidDepState.bar]) + def other(self) -> str: + return "other" + + class Child2(ValidDepState): + @computed_var(deps=["base", Child1.other]) + def other(self) -> str: + return "other" + app._state = ValidDepState app._compile() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 9e1932305..44c3f60b7 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -14,6 +14,7 @@ from typing import ( Any, AsyncGenerator, Callable, + ClassVar, Dict, List, Optional, @@ -1169,13 +1170,17 @@ def test_conditional_computed_vars(): ms = MainState() # Initially there are no dirty computed vars. - assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} - assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} - assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"} + assert ms._dirty_computed_vars(from_vars={"flag"}) == { + (MainState.get_full_name(), "rendered_var") + } + assert ms._dirty_computed_vars(from_vars={"t2"}) == { + (MainState.get_full_name(), "rendered_var") + } + assert ms._dirty_computed_vars(from_vars={"t1"}) == { + (MainState.get_full_name(), "rendered_var") + } assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == { - "flag", - "t1", - "t2", + MainState.get_full_name(): {"flag", "t1", "t2"} } @@ -1370,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() - assert "cached_x_side_effect" in s._computed_var_dependencies["x"] + assert ( + HandlerState.get_full_name(), + "cached_x_side_effect", + ) in s._var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() @@ -1460,15 +1468,15 @@ def test_computed_var_dependencies(): return [z in self._z for z in range(5)] cs = ComputedState() - assert cs._computed_var_dependencies["v"] == { - "comp_v", - "comp_v_backend", - "comp_v_via_property", + assert cs._var_dependencies["v"] == { + (ComputedState.get_full_name(), "comp_v"), + (ComputedState.get_full_name(), "comp_v_backend"), + (ComputedState.get_full_name(), "comp_v_via_property"), } - assert cs._computed_var_dependencies["w"] == {"comp_w"} - assert cs._computed_var_dependencies["x"] == {"comp_x"} - assert cs._computed_var_dependencies["y"] == {"comp_y"} - assert cs._computed_var_dependencies["_z"] == {"comp_z"} + assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")} + assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")} + assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")} + assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")} def test_backend_method(): @@ -3180,7 +3188,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): RxState = State -def test_potentially_dirty_substates(): +def test_potentially_dirty_states(): """Test that potentially_dirty_substates returns the correct substates. Even if the name "State" is shadowed, it should still work correctly. @@ -3196,13 +3204,19 @@ def test_potentially_dirty_substates(): def bar(self) -> str: return "" - assert RxState._potentially_dirty_substates() == set() - assert State._potentially_dirty_substates() == set() - assert C1._potentially_dirty_substates() == set() + assert RxState._get_potentially_dirty_states() == set() + assert State._get_potentially_dirty_states() == set() + assert C1._get_potentially_dirty_states() == set() -def test_router_var_dep() -> None: - """Test that router var dependencies are correctly tracked.""" +@pytest.mark.asyncio +async def test_router_var_dep(state_manager: StateManager, token: str) -> None: + """Test that router var dependencies are correctly tracked. + + Args: + state_manager: A state manager. + token: A token. + """ class RouterVarParentState(State): """A parent state for testing router var dependency.""" @@ -3219,30 +3233,27 @@ def test_router_var_dep() -> None: foo = RouterVarDepState.computed_vars["foo"] State._init_var_dependency_dicts() - assert foo._deps(objclass=RouterVarDepState) == {"router"} - assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} - assert RouterVarParentState._substate_var_dependencies == { - "router": {RouterVarDepState.get_name()} - } - assert RouterVarDepState._computed_var_dependencies == { - "router": {"foo"}, + assert foo._deps(objclass=RouterVarDepState) == { + RouterVarDepState.get_full_name(): {"router"} } + assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[ + "router" + ] - rx_state = State() - parent_state = RouterVarParentState() - state = RouterVarDepState() - - # link states - rx_state.substates = {RouterVarParentState.get_name(): parent_state} - parent_state.parent_state = rx_state - state.parent_state = parent_state - parent_state.substates = {RouterVarDepState.get_name(): state} + # Get state from state manager. + state_manager.state = State + rx_state = await state_manager.get_state(_substate_key(token, State)) + assert RouterVarParentState.get_name() in rx_state.substates + parent_state = rx_state.substates[RouterVarParentState.get_name()] + assert RouterVarDepState.get_name() in parent_state.substates + state = parent_state.substates[RouterVarDepState.get_name()] assert state.dirty_vars == set() # Reassign router var state.router = state.router - assert state.dirty_vars == {"foo", "router"} + assert rx_state.dirty_vars == {"router"} + assert state.dirty_vars == {"foo"} assert parent_state.dirty_substates == {RouterVarDepState.get_name()} @@ -3801,3 +3812,128 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str): # Generic Var with no state with pytest.raises(UnretrievableVarValueError): await state.get_var_value(rx.Var("undefined")) + + +@pytest.mark.asyncio +async def test_async_computed_var_get_state(mock_app: rx.App, token: str): + """A test where an async computed var depends on a var in another state. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class Parent(BaseState): + """A root state like rx.State.""" + + parent_var: int = 0 + + class Child2(Parent): + """An unconnected child state.""" + + pass + + class Child3(Parent): + """A child state with a computed var causing it to be pre-fetched. + + If child3_var gets set to a value, and `get_state` erroneously + re-fetches it from redis, the value will be lost. + """ + + child3_var: int = 0 + + @rx.var(cache=True) + def v(self) -> int: + return self.child3_var + + class Child(Parent): + """A state simulating UpdateVarsInternalState.""" + + @rx.var(cache=True) + async def v(self) -> int: + p = await self.get_state(Parent) + child3 = await self.get_state(Child3) + return child3.child3_var + p.parent_var + + mock_app.state_manager.state = mock_app._state = Parent + + # Get the top level state via unconnected sibling. + root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + # Set value in parent_var to assert it does not get refetched later. + root.parent_var = 1 + + if isinstance(mock_app.state_manager, StateManagerRedis): + # When redis is used, only states with uncached computed vars are pre-fetched. + assert Child2.get_name() not in root.substates + assert Child3.get_name() not in root.substates + + # Get the unconnected sibling state, which will be used to `get_state` other instances. + child = root.get_substate(Child.get_full_name().split(".")) + + # Get an uncached child state. + child2 = await child.get_state(Child2) + assert child2.parent_var == 1 + + # Set value on already-cached Child3 state (prefetched because it has a Computed Var). + child3 = await child.get_state(Child3) + child3.child3_var = 1 + + assert await child.v == 2 + assert await child.v == 2 + root.parent_var = 2 + assert await child.v == 3 + + +class Table(rx.ComponentState): + """A table state.""" + + data: ClassVar[Var] + + @rx.var(cache=True, auto_deps=False) + async def rows(self) -> List[Dict[str, Any]]: + """Computed var over the given rows. + + Returns: + The data rows. + """ + return await self.get_var_value(self.data) + + @classmethod + def get_component(cls, data: Var) -> rx.Component: + """Get the component for the table. + + Args: + data: The data var. + + Returns: + The component. + """ + cls.data = data + cls.computed_vars["rows"].add_dependency(cls, data) + return rx.foreach(data, lambda d: rx.text(d.to_string())) + + +@pytest.mark.asyncio +async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str): + """A test where an async computed var depends on a var in another state. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class OtherState(rx.State): + """A state with a var.""" + + data: List[Dict[str, Any]] = [{"foo": "bar"}] + + mock_app.state_manager.state = mock_app._state = rx.State + comp = Table.create(data=OtherState.data) + state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + other_state = await state.get_state(OtherState) + assert comp.State is not None + comp_state = await state.get_state(comp.State) + assert comp_state.dirty_vars == set() + + other_state.data.append({"foo": "baz"}) + assert "rows" in comp_state.dirty_vars diff --git a/tests/units/test_var.py b/tests/units/test_var.py index ef19e86e8..a72242814 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -1807,9 +1807,9 @@ def cv_fget(state: BaseState) -> int: @pytest.mark.parametrize( "deps,expected", [ - (["a"], {"a"}), - (["b"], {"b"}), - ([ComputedVar(fget=cv_fget)], {"cv_fget"}), + (["a"], {None: {"a"}}), + (["b"], {None: {"b"}}), + ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}), ], ) def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]): @@ -1857,6 +1857,28 @@ def test_to_string_operation(): assert single_var._var_type == Email +@pytest.mark.asyncio +async def test_async_computed_var(): + side_effect_counter = 0 + + class AsyncComputedVarState(BaseState): + v: int = 1 + + @computed_var(cache=True) + async def async_computed_var(self) -> int: + nonlocal side_effect_counter + side_effect_counter += 1 + return self.v + 1 + + my_state = AsyncComputedVarState() + assert await my_state.async_computed_var == 2 + assert await my_state.async_computed_var == 2 + my_state.v = 2 + assert await my_state.async_computed_var == 3 + assert await my_state.async_computed_var == 3 + assert side_effect_counter == 2 + + def test_var_data_hooks(): var_data_str = VarData(hooks="what") var_data_list = VarData(hooks=["what"])