From 3d50c1b623c70a0542d5f5edcb778342902baf39 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 22 Jan 2025 05:00:49 -0800 Subject: [PATCH] WiP --- reflex/app.py | 16 +- reflex/compiler/utils.py | 5 +- reflex/middleware/hydrate_middleware.py | 4 +- reflex/state.py | 433 ++++++++++++++---------- reflex/vars/base.py | 303 +++++++++++++---- tests/units/test_app.py | 18 +- tests/units/test_state.py | 119 +++++-- tests/units/test_var.py | 29 +- 8 files changed, 654 insertions(+), 273 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 7e868e730..6b9a64ca7 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -833,11 +833,17 @@ class App(MiddlewareMixin, LifespanMixin): if not var._cache: continue deps = var._deps(objclass=state) - for dep in deps: - if dep not in state.vars and dep not in state.backend_vars: - raise exceptions.VarDependencyError( - f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}" - ) + for state_name, dep_set in deps.items(): + state_cls = ( + state.get_root_state().get_class_substate(state_name) + if state_name != state.get_full_name() + else state + ) + for dep in dep_set: + if dep not in state_cls.vars and dep not in state_cls.backend_vars: + raise exceptions.VarDependencyError( + f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}" + ) for substate in state.class_subclasses: self._validate_var_dependencies(substate) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index c0ba28f4b..f5e79a796 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union from urllib.parse import urlparse @@ -29,7 +30,7 @@ from reflex.components.base import ( ) from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.istate.storage import Cookie, LocalStorage, SessionStorage -from reflex.state import BaseState +from reflex.state import BaseState, _resolve_delta from reflex.style import Style from reflex.utils import console, format, imports, path_ops from reflex.utils.imports import ImportVar, ParsedImportDict @@ -169,7 +170,7 @@ def compile_state(state: Type[BaseState]) -> dict: initial_state = state(_reflex_internal_init=True).dict( initial=True, include_computed=False ) - return initial_state + return asyncio.run(_resolve_delta(initial_state)) def _compile_client_storage_field( diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 2198b82c2..2dea54e17 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional from reflex import constants from reflex.event import Event, get_hydrate_event from reflex.middleware.middleware import Middleware -from reflex.state import BaseState, StateUpdate +from reflex.state import BaseState, StateUpdate, _resolve_delta if TYPE_CHECKING: from reflex.app import App @@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware): setattr(state, constants.CompileVars.IS_HYDRATED, False) # Get the initial state. - delta = state.dict() + delta = await _resolve_delta(state.dict()) # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/state.py b/reflex/state.py index 66098d232..6ef3ff3e8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -328,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): ) +async def _resolve_delta(delta: Delta) -> Delta: + """Await all coroutines in the delta. + + Args: + delta: The delta to process. + + Returns: + The same delta dict with all coroutines resolved to their return value. + """ + tasks = {} + for state_name, state_delta in delta.items(): + for var_name, value in state_delta.items(): + if asyncio.iscoroutine(value): + tasks[state_name, var_name] = asyncio.create_task(value) + for (state_name, var_name), task in tasks.items(): + delta[state_name][var_name] = await task + return delta + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -355,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A set of subclassses of this class. class_subclasses: ClassVar[Set[Type[BaseState]]] = set() - # Mapping of var name to set of computed variables that depend on it - _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} - - # Mapping of var name to set of substates that depend on it - _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} + # Mapping of var name to set of (state_full_name, var_name) that depend on it. + _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {} # Set of vars which always need to be recomputed _always_dirty_computed_vars: ClassVar[Set[str]] = set() @@ -367,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Set of substates which always need to be recomputed _always_dirty_substates: ClassVar[Set[str]] = set() + # Set of states which might need to be recomputed if vars in this state change. + _potentially_dirty_states: ClassVar[Set[str]] = set() + # The parent state. parent_state: Optional[BaseState] = None @@ -518,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Reset dirty substate tracking for this class. cls._always_dirty_substates = set() + cls._potentially_dirty_states = set() # Get the parent vars. parent_state = cls.get_parent_state() @@ -621,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): setattr(cls, name, handler) # Initialize per-class var dependency tracking. - cls._computed_var_dependencies = defaultdict(set) - cls._substate_var_dependencies = defaultdict(set) + cls._var_dependencies = {} cls._init_var_dependency_dicts() @staticmethod @@ -767,26 +786,25 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Additional updates tracking dicts for vars and substates that always need to be recomputed. """ - inherited_vars = set(cls.inherited_vars).union( - set(cls.inherited_backend_vars), - ) for cvar_name, cvar in cls.computed_vars.items(): - # Add the dependencies. - for var in cvar._deps(objclass=cls): - cls._computed_var_dependencies[var].add(cvar_name) - if var in inherited_vars: - # track that this substate depends on its parent for this var - state_name = cls.get_name() - parent_state = cls.get_parent_state() - while parent_state is not None and var in { - **parent_state.vars, - **parent_state.backend_vars, + if not cvar._cache: + # Do not perform dep calculation when cache=False (these are always dirty). + continue + for state_name, dvar_set in cvar._deps(objclass=cls).items(): + state_cls = cls.get_root_state().get_class_substate(state_name) + for dvar in dvar_set: + defining_state_cls = state_cls + while dvar in { + *defining_state_cls.inherited_vars, + *defining_state_cls.inherited_backend_vars, }: - parent_state._substate_var_dependencies[var].add(state_name) - state_name, parent_state = ( - parent_state.get_name(), - parent_state.get_parent_state(), - ) + defining_state_cls = defining_state_cls.get_parent_state() + defining_state_cls._var_dependencies.setdefault(dvar, set()).add( + (cls.get_full_name(), cvar_name) + ) + defining_state_cls._potentially_dirty_states.add( + cls.get_full_name() + ) # ComputedVar with cache=False always need to be recomputed cls._always_dirty_computed_vars = { @@ -901,6 +919,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): raise ValueError(f"Only one parent state is allowed {parent_states}.") return parent_states[0] if len(parent_states) == 1 else None # type: ignore + @classmethod + @functools.lru_cache() + def get_root_state(cls) -> Type[BaseState]: + """Get the root state. + + Returns: + The root state. + """ + parent_state = cls.get_parent_state() + return cls if parent_state is None else parent_state.get_root_state() + @classmethod def get_substates(cls) -> set[Type[BaseState]]: """Get the substates of the state. @@ -1353,7 +1382,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): super().__setattr__(name, value) # Add the var to the dirty list. - if name in self.vars or name in self._computed_var_dependencies: + if name in self.base_vars: self.dirty_vars.add(name) self._mark_dirty() @@ -1423,6 +1452,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): raise ValueError(f"Invalid path: {path}") return self.substates[path[0]].get_substate(path[1:]) + @classmethod + def _get_potentially_dirty_states(cls) -> set[type[BaseState]]: + """Get substates which may have dirty vars due to dependencies. + + Returns: + The set of potentially dirty substate classes. + """ + return { + cls.get_class_substate(substate_name) + for substate_name in cls._always_dirty_substates + }.union( + { + cls.get_root_state().get_class_substate(substate_name) + for substate_name in cls._potentially_dirty_states + } + ) + @classmethod def _get_common_ancestor(cls, other: Type[BaseState]) -> str: """Find the name of the nearest common ancestor shared by this and the other state. @@ -1493,55 +1539,37 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): parent_state = parent_state.parent_state return parent_state - async def _populate_parent_states(self, target_state_cls: Type[BaseState]): - """Populate substates in the tree between the target_state_cls and common ancestor of this state. + async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: + """Get a state instance from redis. Args: - target_state_cls: The class of the state to populate parent states for. + state_cls: The class of the state. Returns: - The parent state instance of target_state_cls. + The instance of state_cls associated with this state's client_token. Raises: RuntimeError: If redis is not used in this backend process. + StateMismatchError: If the state instance is not of the expected type. """ + # Then get the target state and all its substates. state_manager = get_state_manager() if not isinstance(state_manager, StateManagerRedis): raise RuntimeError( - f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. " + f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) + state_in_redis = await state_manager._link_arbitrary_state( + self, + state_cls, + ) - # Find the missing parent states up to the common ancestor. - ( - common_ancestor_name, - missing_parent_states, - ) = self._determine_missing_parent_states(target_state_cls) - - # Fetch all missing parent states and link them up to the common ancestor. - parent_states_tuple = self._get_parent_states() - root_state = parent_states_tuple[-1][1] - parent_states_by_name = dict(parent_states_tuple) - parent_state = parent_states_by_name[common_ancestor_name] - for parent_state_name in missing_parent_states: - try: - parent_state = root_state.get_substate(parent_state_name.split(".")) - # The requested state is already cached, do NOT fetch it again. - continue - except ValueError: - # The requested state is missing, fetch from redis. - pass - parent_state = await state_manager.get_state( - token=_substate_key( - self.router.session.client_token, parent_state_name - ), - top_level=False, - get_substates=False, - parent_state=parent_state, + if not isinstance(state_in_redis, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." ) - # Return the direct parent of target_state_cls for subsequent linking. - return parent_state + return state_in_redis def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from the cache. @@ -1563,44 +1591,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) return substate - async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: - """Get a state instance from redis. - - Args: - state_cls: The class of the state. - - Returns: - The instance of state_cls associated with this state's client_token. - - Raises: - RuntimeError: If redis is not used in this backend process. - StateMismatchError: If the state instance is not of the expected type. - """ - # Fetch all missing parent states from redis. - parent_state_of_state_cls = await self._populate_parent_states(state_cls) - - # Then get the target state and all its substates. - state_manager = get_state_manager() - if not isinstance(state_manager, StateManagerRedis): - raise RuntimeError( - f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " - "(All states should already be available -- this is likely a bug).", - ) - - state_in_redis = await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), - top_level=False, - get_substates=True, - parent_state=parent_state_of_state_cls, - ) - - if not isinstance(state_in_redis, state_cls): - raise StateMismatchError( - f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." - ) - - return state_in_redis - async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE: """Get an instance of the state associated with this token. @@ -1737,7 +1727,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) - def _as_state_update( + async def _as_state_update( self, handler: EventHandler, events: EventSpec | list[EventSpec] | None, @@ -1765,7 +1755,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): try: # Get the delta after processing the event. - delta = state.get_delta() + delta = await _resolve_delta(state.get_delta()) state._clean() return StateUpdate( @@ -1865,24 +1855,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Handle async generators. if inspect.isasyncgen(events): async for event in events: - yield state._as_state_update(handler, event, final=False) - yield state._as_state_update(handler, events=None, final=True) + yield await state._as_state_update(handler, event, final=False) + yield await state._as_state_update(handler, events=None, final=True) # Handle regular generators. elif inspect.isgenerator(events): try: while True: - yield state._as_state_update(handler, next(events), final=False) + yield await state._as_state_update( + handler, next(events), final=False + ) except StopIteration as si: # the "return" value of the generator is not available # in the loop, we must catch StopIteration to access it if si.value is not None: - yield state._as_state_update(handler, si.value, final=False) - yield state._as_state_update(handler, events=None, final=True) + yield await state._as_state_update( + handler, si.value, final=False + ) + yield await state._as_state_update(handler, events=None, final=True) # Handle regular event chains. else: - yield state._as_state_update(handler, events, final=True) + yield await state._as_state_update(handler, events, final=True) # If an error occurs, throw a window alert. except Exception as ex: @@ -1892,7 +1886,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): prerequisites.get_and_validate_app().app.backend_exception_handler(ex) ) - yield state._as_state_update( + yield await state._as_state_update( handler, event_specs, final=True, @@ -1900,15 +1894,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" + # Append expired computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._expired_computed_vars()) + # Append always dirty computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._always_dirty_computed_vars) + dirty_vars = self.dirty_vars while dirty_vars: calc_vars, dirty_vars = dirty_vars, set() - for cvar in self._dirty_computed_vars(from_vars=calc_vars): - self.dirty_vars.add(cvar) + for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars): + if state_name == self.get_full_name(): + defining_state = self + else: + defining_state = self._get_root_state().get_substate( + tuple(state_name.split(".")) + ) + defining_state.dirty_vars.add(cvar) dirty_vars.add(cvar) - actual_var = self.computed_vars.get(cvar) + actual_var = defining_state.computed_vars.get(cvar) if actual_var is not None: - actual_var.mark_dirty(instance=self) + actual_var.mark_dirty(instance=defining_state) + if defining_state is not self: + defining_state._mark_dirty() def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. @@ -1924,7 +1931,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def _dirty_computed_vars( self, from_vars: set[str] | None = None, include_backend: bool = True - ) -> set[str]: + ) -> set[tuple[str, str]]: """Determine ComputedVars that need to be recalculated based on the given vars. Args: @@ -1935,32 +1942,59 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Set of computed vars to include in the delta. """ return { - cvar + (state_name, cvar) for dirty_var in from_vars or self.dirty_vars - for cvar in self._computed_var_dependencies[dirty_var] + for state_name, cvar in self._var_dependencies.get(dirty_var, set()) if include_backend or not self.computed_vars[cvar]._backend } - @classmethod - def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: - """Determine substates which could be affected by dirty vars in this state. + async def _recursively_populate_dependent_substates( + self, + seen_classes: set[type[BaseState]] | None = None, + ) -> set[type[BaseState]]: + """Fetch all substates that have computed var dependencies on this state. + + Args: + seen_classes: set of classes that have already been seen to prevent infinite recursion. Returns: - Set of State classes that may need to be fetched to recalc computed vars. + The set of classes that were processed (mostly for testability). """ - # _always_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) - for substate_name in cls._always_dirty_substates - } - for dependent_substates in cls._substate_var_dependencies.values(): - fetch_substates.update( - { - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) - for substate_name in dependent_substates - } + if seen_classes is None: + print( + f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:" ) - return fetch_substates + seen_classes = set() + if type(self) in seen_classes: + return seen_classes + seen_classes.add(type(self)) + populated_substate_instances = {} + for substate_cls in { + self.get_class_substate((self.get_name(), *substate_name.split("."))) + for substate_name in self._always_dirty_substates + }: + # _always_dirty_substates need to be fetched to recalc computed vars. + if substate_cls not in populated_substate_instances: + print(f"fetching always dirty {substate_cls}") + populated_substate_instances[substate_cls] = await self.get_state( + substate_cls + ) + for dep_set in self._var_dependencies.values(): + for substate_name, _ in dep_set: + if substate_name == self.get_full_name(): + # Do NOT fetch our own state instance. + continue + substate_cls = self.get_root_state().get_class_substate(substate_name) + if substate_cls not in populated_substate_instances: + print(f"fetching dependent {substate_cls}") + populated_substate_instances[substate_cls] = await self.get_state( + substate_cls + ) + for substate in populated_substate_instances.values(): + await substate._recursively_populate_dependent_substates( + seen_classes=seen_classes, + ) + return seen_classes def get_delta(self) -> Delta: """Get the delta for the state. @@ -1970,21 +2004,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ delta = {} - # Apply dirty variables down into substates - self.dirty_vars.update(self._always_dirty_computed_vars) - self._mark_dirty() - + self._mark_dirty_computed_vars() frontend_computed_vars: set[str] = { name for name, cv in self.computed_vars.items() if not cv._backend } # Return the dirty vars for this instance, any cached/dependent computed vars, # and always dirty computed vars (cache=False) - delta_vars = ( - self.dirty_vars.intersection(self.base_vars) - .union(self.dirty_vars.intersection(frontend_computed_vars)) - .union(self._dirty_computed_vars(include_backend=False)) - .union(self._always_dirty_computed_vars) + delta_vars = self.dirty_vars.intersection(self.base_vars).union( + self.dirty_vars.intersection(frontend_computed_vars) ) subdelta: Dict[str, Any] = { @@ -2014,23 +2042,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state._mark_dirty() - # Append expired computed vars to dirty_vars to trigger recalculation - self.dirty_vars.update(self._expired_computed_vars()) - # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function self._mark_dirty_computed_vars() - self._mark_dirty_substates() - - def _mark_dirty_substates(self): - """Propagate dirty var / computed var status into substates.""" - substates = self.substates - for var in self.dirty_vars: - for substate_name in self._substate_var_dependencies[var]: - self.dirty_substates.add(substate_name) - substate = substates[substate_name] - substate.dirty_vars.add(var) - substate._mark_dirty() def _update_was_touched(self): """Update the _was_touched flag based on dirty_vars.""" @@ -2102,11 +2116,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): The object as a dictionary. """ if include_computed: - # Apply dirty variables down into substates to allow never-cached ComputedVar to - # trigger recalculation of dependent vars - self.dirty_vars.update(self._always_dirty_computed_vars) - self._mark_dirty() - + self._mark_dirty_computed_vars() base_vars = { prop_name: self.get_value(prop_name) for prop_name in self.base_vars } @@ -3339,6 +3349,79 @@ class StateManagerRedis(StateManager): ) return parent_state + async def _populate_parent_states( + self, calling_state: BaseState, target_state_cls: Type[BaseState] + ): + """Populate substates in the tree between the target_state_cls and common ancestor of calling_state. + + Args: + calling_state: The substate instance requesting subtree population. + target_state_cls: The class of the state to populate parent states for. + + Returns: + The parent state instance of target_state_cls. + """ + # Find the missing parent states up to the common ancestor. + ( + common_ancestor_name, + missing_parent_states, + ) = calling_state._determine_missing_parent_states(target_state_cls) + + # Fetch all missing parent states and link them up to the common ancestor. + parent_states_tuple = calling_state._get_parent_states() + root_state = parent_states_tuple[-1][1] + parent_states_by_name = dict(parent_states_tuple) + parent_state = parent_states_by_name[common_ancestor_name] + for parent_state_name in missing_parent_states: + try: + parent_state = root_state.get_substate(parent_state_name.split(".")) + # The requested state is already cached, do NOT fetch it again. + continue + except ValueError: + # The requested state is missing, fetch from redis. + pass + parent_state = await self.get_state( + token=_substate_key( + calling_state.router.session.client_token, parent_state_name + ), + top_level=False, + get_substates=False, + parent_state=parent_state, + ) + + # Return the direct parent of target_state_cls for subsequent linking. + return parent_state + + async def _link_arbitrary_state( + self, calling_state: BaseState, state_cls: Type[T_STATE] + ) -> T_STATE: + """Get a state instance from redis. + + Args: + calling_state: The state instance requesting the newly linked instance of state_cls. + state_cls: The class of the state to link into the tree. + + Returns: + The instance of state_cls associated with calling_state's client_token. + + Raises: + StateMismatchError: If the state instance is not of the expected type. + """ + # Fetch all missing parent states from redis. + parent_state_of_state_cls = await self._populate_parent_states( + calling_state, state_cls + ) + + # Then get the target state and all its substates. + state_in_redis = await self.get_state( + token=_substate_key(calling_state.router.session.client_token, state_cls), + top_level=False, + get_substates=True, + parent_state=parent_state_of_state_cls, + ) + + return state_in_redis + async def _populate_substates( self, token: str, @@ -3357,30 +3440,40 @@ class StateManagerRedis(StateManager): """ client_token, _ = _split_substate_key(token) + # Only _potentially_dirty_substates need to be fetched to recalc computed vars. + fetch_substates = state._get_potentially_dirty_states() if all_substates: # All substates are requested. - fetch_substates = state.get_substates() - else: - # Only _potentially_dirty_substates need to be fetched to recalc computed vars. - fetch_substates = state._potentially_dirty_substates() + fetch_substates.update(state.get_substates()) tasks = {} + link_tasks = set() # Retrieve the necessary substates from redis. for substate_cls in fetch_substates: if substate_cls.get_name() in state.substates: continue substate_name = substate_cls.get_name() - tasks[substate_name] = asyncio.create_task( - self.get_state( - token=_substate_key(client_token, substate_cls), - top_level=False, - get_substates=all_substates, - parent_state=state, + if substate_cls in state.get_substates(): + tasks[substate_name] = asyncio.create_task( + self.get_state( + token=_substate_key(client_token, substate_cls), + top_level=False, + get_substates=all_substates, + parent_state=state, + ) ) - ) + else: + try: + state._get_root_state().get_substate(substate_name.split(".")) + except ValueError: + # The requested state is missing, so fetch and link it (and its parents). + link_tasks.add( + asyncio.create_task(self._link_arbitrary_state(state, substate_cls)) + ) for substate_name, substate_task in tasks.items(): state.substates[substate_name] = await substate_task + await asyncio.gather(*link_tasks) @override async def get_state( @@ -4153,7 +4246,7 @@ def reload_state_module( if subclass.__module__ == module and module is not None: state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) - state._computed_var_dependencies = defaultdict(set) - state._substate_var_dependencies = defaultdict(set) + state._potentially_dirty_substates.discard(subclass.get_name()) + state._var_dependencies = {} state._init_var_dependency_dicts() state.get_class_substate.cache_clear() diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 122545187..6bc5b25c4 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1826,7 +1826,7 @@ class ComputedVar(Var[RETURN_TYPE]): _initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset()) # Explicit var dependencies to track - _static_deps: set[str] = dataclasses.field(default_factory=set) + _static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict) # Whether var dependencies should be auto-determined _auto_deps: bool = dataclasses.field(default=True) @@ -1901,21 +1901,40 @@ class ComputedVar(Var[RETURN_TYPE]): object.__setattr__(self, "_update_interval", interval) - if deps is None: - deps = [] - else: + _static_deps = {} + if isinstance(deps, dict): + # Assume a dict is coming from _replace, so no special processing. + _static_deps = deps + elif deps is not None: for dep in deps: if isinstance(dep, Var): - continue - if isinstance(dep, str) and dep != "": - continue - raise TypeError( - "ComputedVar dependencies must be Var instances or var names (non-empty strings)." - ) + state_name = ( + all_var_data.state + if (all_var_data := dep._get_all_var_data()) + and all_var_data.state + else None + ) + var_name = ( + dep._js_expr[len(formatted_state_prefix) :] + if state_name + and ( + formatted_state_prefix := format_state_name(state_name) + + "." + ) + and dep._js_expr.startswith(formatted_state_prefix) + else dep._js_expr + ) + _static_deps.setdefault(state_name, set()).add(var_name) + elif isinstance(dep, str) and dep != "": + _static_deps.setdefault(None, set()).add(dep) + else: + raise TypeError( + "ComputedVar dependencies must be Var instances or var names (non-empty strings)." + ) object.__setattr__( self, "_static_deps", - {dep._js_expr if isinstance(dep, Var) else dep for dep in deps}, + _static_deps, ) object.__setattr__(self, "_auto_deps", auto_deps) @@ -2081,6 +2100,11 @@ class ComputedVar(Var[RETURN_TYPE]): setattr(instance, self._last_updated_attr, datetime.datetime.now()) value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + + return value + + def _check_deprecated_return_type(self, instance, value) -> None: if not _isinstance(value, self._var_type): console.deprecate( "mismatched-computed-var-return", @@ -2090,41 +2114,49 @@ class ComputedVar(Var[RETURN_TYPE]): "0.7.0", ) - return value - def _deps( self, - objclass: Type, + objclass: BaseState, obj: FunctionType | CodeType | None = None, - self_name: Optional[str] = None, - ) -> set[str]: + self_names: Optional[dict[str, str]] = None, + ) -> dict[str, set[str]]: """Determine var dependencies of this ComputedVar. - Save references to attributes accessed on "self". Recursively called - when the function makes a method call on "self" or define comprehensions - or nested functions that may reference "self". + Save references to attributes accessed on "self" or other fetched states. + + Recursively called when the function makes a method call on "self" or + define comprehensions or nested functions that may reference "self". Args: objclass: the class obj this ComputedVar is attached to. obj: the object to disassemble (defaults to the fget function). - self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions. + self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions. Returns: - A set of variable names accessed by the given obj. + A dictionary mapping state names to the set of variable names + accessed by the given obj. Raises: VarValueError: if the function references the get_state, parent_state, or substates attributes (cannot track deps in a related state, only implicitly via parent state). """ + from reflex.state import BaseState + + d = {} + if self._static_deps: + d.update(self._static_deps) + # None is a placeholder for the current state class. + if None in d: + d[objclass.get_full_name()] = d.pop(None) + if not self._auto_deps: - return self._static_deps - d = self._static_deps.copy() + return d if obj is None: fget = self._fget if fget is not None: obj = cast(FunctionType, fget) else: - return set() + return d with contextlib.suppress(AttributeError): # unbox functools.partial obj = cast(FunctionType, obj.func) # type: ignore @@ -2132,76 +2164,150 @@ class ComputedVar(Var[RETURN_TYPE]): # unbox EventHandler obj = cast(FunctionType, obj.fn) # type: ignore - if self_name is None and isinstance(obj, FunctionType): + if self_names is None and isinstance(obj, FunctionType): try: # the first argument to the function is the name of "self" arg - self_name = obj.__code__.co_varnames[0] + self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()} except (AttributeError, IndexError): - self_name = None - if self_name is None: + self_names = None + if self_names is None: # cannot reference attributes on self if method takes no args - return set() + return d - invalid_names = ["get_state", "parent_state", "substates", "get_substate"] - self_is_top_of_stack = False + invalid_names = ["parent_state", "substates", "get_substate"] + self_on_top_of_stack = None + getting_state = False + getting_var = False for instruction in dis.get_instructions(obj): + if getting_state: + if instruction.opname == "LOAD_FAST": + raise VarValueError( + f"Dependency detection cannot identify get_state class from local var {instruction.argval}." + ) + if instruction.opname == "LOAD_GLOBAL": + # Special case: referencing state class from global scope. + getting_state = obj.__globals__.get(instruction.argval) + elif instruction.opname == "LOAD_DEREF": + # Special case: referencing state class from closure. + closure = dict(zip(obj.__code__.co_freevars, obj.__closure__)) + try: + getting_state = closure[instruction.argval].cell_contents + except ValueError as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." + ) from ve + elif instruction.opname == "STORE_FAST": + # Storing the result of get_state in a local variable. + if not isinstance(getting_state, type) or not issubclass( + getting_state, BaseState + ): + raise VarValueError( + f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`." + ) + self_names[instruction.argval] = getting_state.get_full_name() + getting_state = False + continue # nothing else happens until we have identified the local var + if getting_var: + if instruction.opname == "CALL": + # get the original source code and eval it + start_line = getting_var[0].positions.lineno + start_column = getting_var[0].positions.col_offset + end_line = getting_var[-1].positions.end_lineno + end_column = getting_var[-1].positions.end_col_offset + source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line] + if len(source) > 1: + snipped_source = "".join( + [ + source[0][start_column:], + source[1:-2] if len(source) > 2 else "", + source[-1][:end_column] + ] + ) + else: + snipped_source = source[0][start_column:end_column] + the_var = eval(f"({snipped_source})", obj.__globals__) + print(the_var) + # code = source[start_line - 1] + # bytecode = bytearray((dis.opmap["RESUME"], 0)) + # for ins in getting_var: + # bytecode.append(ins.opcode) + # bytecode.append(ins.arg or 0 & 0xFF) + # bytecode.extend((dis.opmap["RETURN_VALUE"], 0)) + # bc = dis.Bytecode(obj) + # code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=()) + # breakpoint() + getting_var = False + elif isinstance(getting_var, list): + getting_var.append(instruction) + else: + getting_var = [instruction] + continue if ( instruction.opname in ("LOAD_FAST", "LOAD_DEREF") - and instruction.argval == self_name + and instruction.argval in self_names ): # bytecode loaded the class instance to the top of stack, next load instruction # is referencing an attribute on self - self_is_top_of_stack = True + self_on_top_of_stack = self_names[instruction.argval] continue - if self_is_top_of_stack and instruction.opname in ( + if self_on_top_of_stack and instruction.opname in ( "LOAD_ATTR", "LOAD_METHOD", ): - try: - ref_obj = getattr(objclass, instruction.argval) - except Exception: - ref_obj = None if instruction.argval in invalid_names: raise VarValueError( f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`." ) + if instruction.argval == "get_state": + # Special case: arbitrary state access requested. + getting_state = True + continue + if instruction.argval == "get_var_value": + # Special case: arbitrary var access requested. + getting_var = True + continue + print(f"{self_on_top_of_stack=}") + target_state = objclass.get_root_state().get_class_substate( + self_on_top_of_stack + ) + try: + ref_obj = getattr(target_state, instruction.argval) + except Exception: + ref_obj = None if callable(ref_obj): # recurse into callable attributes - d.update( - self._deps( - objclass=objclass, - obj=ref_obj, - ) - ) + for state_name, dep_name in self._deps( + objclass=target_state, + obj=ref_obj, + ).items(): + d.setdefault(state_name, set()).update(dep_name) # recurse into property fget functions elif isinstance(ref_obj, property) and not isinstance( ref_obj, ComputedVar ): - d.update( - self._deps( - objclass=objclass, - obj=ref_obj.fget, # type: ignore - ) - ) + for state_name, dep_name in self._deps( + objclass=target_state, + obj=ref_obj.fget, # type: ignore + ).items(): + d.setdefault(state_name, set()).update(dep_name) elif ( - instruction.argval in objclass.backend_vars - or instruction.argval in objclass.vars + instruction.argval in target_state.backend_vars + or instruction.argval in target_state.vars ): # var access - d.add(instruction.argval) + d.setdefault(self_on_top_of_stack, set()).add(instruction.argval) elif instruction.opname == "LOAD_CONST" and isinstance( instruction.argval, CodeType ): # recurse into nested functions / comprehensions, which can reference # instance attributes from the outer scope - d.update( - self._deps( - objclass=objclass, - obj=instruction.argval, - self_name=self_name, - ) - ) - self_is_top_of_stack = False + for state_name, dep_name in self._deps( + objclass=objclass, + obj=instruction.argval, + self_names=self_names, + ).items(): + d.setdefault(state_name, set()).update(dep_name) + self_on_top_of_stack = None return d def mark_dirty(self, instance) -> None: @@ -2249,6 +2355,60 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]): pass +@dataclasses.dataclass( + eq=False, + frozen=True, + init=False, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class AsyncComputedVar(ComputedVar[RETURN_TYPE]): + """A computed var that wraps a coroutinefunction.""" + + _fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field( + default_factory=lambda: lambda _: None + ) # type: ignore + + def __get__(self, instance: BaseState | None, owner): + """Get the ComputedVar value. + + If the value is already cached on the instance, return the cached value. + + Args: + instance: the instance of the class accessing this computed var. + owner: the class that this descriptor is attached to. + + Returns: + The value of the var for the given instance. + """ + if instance is None: + return super(AsyncComputedVar, self).__get__(instance, owner) + + if not self._cache: + + async def _awaitable_result(): + value = await self.fget(instance) + self._check_deprecated_return_type(instance, value) + + return _awaitable_result() + else: + # handle caching + async def _awaitable_result(): + if not hasattr(instance, self._cache_attr) or self.needs_update( + instance + ): + # Set cache attr on state instance. + setattr(instance, self._cache_attr, await self.fget(instance)) + # Ensure the computed var gets serialized to redis. + instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + value = getattr(instance, self._cache_attr) + self._check_deprecated_return_type(instance, value) + return value + + return _awaitable_result() + + if TYPE_CHECKING: BASE_STATE = TypeVar("BASE_STATE", bound=BaseState) @@ -2315,10 +2475,27 @@ def computed_var( raise VarDependencyError("Cannot track dependencies without caching.") if fget is not None: - return ComputedVar(fget, cache=cache) + if inspect.iscoroutinefunction(fget): + computed_var_cls = AsyncComputedVar + else: + computed_var_cls = ComputedVar + return computed_var_cls( + fget, + initial_value=initial_value, + cache=cache, + deps=deps, + auto_deps=auto_deps, + interval=interval, + backend=backend, + **kwargs, + ) def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar: - return ComputedVar( + if inspect.iscoroutinefunction(fget): + computed_var_cls = AsyncComputedVar + else: + computed_var_cls = ComputedVar + return computed_var_cls( fget, initial_value=initial_value, cache=cache, diff --git a/tests/units/test_app.py b/tests/units/test_app.py index f805f83ec..bd9872949 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool): assert app.pages.keys() == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { - constants.ROUTER + EmptyState.get_full_name(): {constants.ROUTER}, } - assert constants.ROUTER in app.state()._computed_var_dependencies + assert constants.ROUTER in app.state()._var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): @@ -997,9 +997,9 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert arg_name in app.state.vars assert arg_name in app.state.computed_vars assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == { - constants.ROUTER + DynamicState.get_full_name(): {constants.ROUTER}, } - assert constants.ROUTER in app.state()._computed_var_dependencies + assert constants.ROUTER in app.state()._var_dependencies substate_token = _substate_key(token, DynamicState) sid = "mock_sid" @@ -1557,6 +1557,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]): def bar(self) -> str: return "bar" + class Child1(ValidDepState): + @computed_var(deps=["base", ValidDepState.bar]) + def other(self) -> str: + return "other" + + class Child2(ValidDepState): + @computed_var(deps=["base", Child1.other]) + def other(self) -> str: + return "other" + app.state = ValidDepState app._compile() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 19f3e4239..c5e2b1287 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1170,13 +1170,11 @@ def test_conditional_computed_vars(): ms = MainState() # Initially there are no dirty computed vars. - assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} - assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} - assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"} + assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")} + assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")} + assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")} assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == { - "flag", - "t1", - "t2", + MainState.get_full_name(): {"flag", "t1", "t2"} } @@ -1371,7 +1369,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() - assert "cached_x_side_effect" in s._computed_var_dependencies["x"] + assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() @@ -1461,15 +1459,15 @@ def test_computed_var_dependencies(): return [z in self._z for z in range(5)] cs = ComputedState() - assert cs._computed_var_dependencies["v"] == { - "comp_v", - "comp_v_backend", - "comp_v_via_property", + assert cs._var_dependencies["v"] == { + (ComputedState.get_full_name(), "comp_v"), + (ComputedState.get_full_name(), "comp_v_backend"), + (ComputedState.get_full_name(), "comp_v_via_property"), } - assert cs._computed_var_dependencies["w"] == {"comp_w"} - assert cs._computed_var_dependencies["x"] == {"comp_x"} - assert cs._computed_var_dependencies["y"] == {"comp_y"} - assert cs._computed_var_dependencies["_z"] == {"comp_z"} + assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")} + assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")} + assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")} + assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")} def test_backend_method(): @@ -3182,6 +3180,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): RxState = State +@pytest.mark.skip(reason="This test is maybe not relevant anymore.") def test_potentially_dirty_substates(): """Test that potentially_dirty_substates returns the correct substates. @@ -3203,7 +3202,8 @@ def test_potentially_dirty_substates(): assert C1._potentially_dirty_substates() == set() -def test_router_var_dep() -> None: +@pytest.mark.asyncio +async def test_router_var_dep() -> None: """Test that router var dependencies are correctly tracked.""" class RouterVarParentState(State): @@ -3221,13 +3221,9 @@ def test_router_var_dep() -> None: foo = RouterVarDepState.computed_vars["foo"] State._init_var_dependency_dicts() - assert foo._deps(objclass=RouterVarDepState) == {"router"} - assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} - assert RouterVarParentState._substate_var_dependencies == { - "router": {RouterVarDepState.get_name()} - } - assert RouterVarDepState._computed_var_dependencies == { - "router": {"foo"}, + assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}} + assert State._var_dependencies == { + "router": {(RouterVarDepState.get_full_name(), "foo")} } rx_state = State() @@ -3240,11 +3236,15 @@ def test_router_var_dep() -> None: state.parent_state = parent_state parent_state.substates = {RouterVarDepState.get_name(): state} + populated_substate_classes = await rx_state._recursively_populate_dependent_substates() + assert populated_substate_classes == {State, RouterVarDepState} + assert state.dirty_vars == set() # Reassign router var state.router = state.router - assert state.dirty_vars == {"foo", "router"} + assert rx_state.dirty_vars == {"router"} + assert state.dirty_vars == {"foo"} assert parent_state.dirty_substates == {RouterVarDepState.get_name()} @@ -3803,3 +3803,74 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str): # Generic Var with no state with pytest.raises(UnretrievableVarValueError): await state.get_var_value(rx.Var("undefined")) + + +@pytest.mark.asyncio +async def test_async_computed_var_get_state(mock_app: rx.App, token: str): + """A test where an async computed var depends on a var in another state. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + + class Parent(BaseState): + """A root state like rx.State.""" + + parent_var: int = 0 + + class Child2(Parent): + """An unconnected child state.""" + + pass + + class Child3(Parent): + """A child state with a computed var causing it to be pre-fetched. + + If child3_var gets set to a value, and `get_state` erroneously + re-fetches it from redis, the value will be lost. + """ + + child3_var: int = 0 + + @rx.var(cache=True) + def v(self): + return self.child3_var + + class Child(Parent): + """A state simulating UpdateVarsInternalState.""" + + @rx.var(cache=True) + async def v(self): + p = await self.get_state(Parent) + child3 = await self.get_state(Child3) + return child3.child3_var + p.parent_var + + mock_app.state_manager.state = mock_app.state = Parent + + # Get the top level state via unconnected sibling. + root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + # Set value in parent_var to assert it does not get refetched later. + root.parent_var = 1 + + if isinstance(mock_app.state_manager, StateManagerRedis): + # When redis is used, only states with uncached computed vars are pre-fetched. + assert Child2.get_name() not in root.substates + assert Child3.get_name() not in root.substates + + # Get the unconnected sibling state, which will be used to `get_state` other instances. + child = root.get_substate(Child.get_full_name().split(".")) + + # Get an uncached child state. + child2 = await child.get_state(Child2) + assert child2.parent_var == 1 + + # Set value on already-cached Child3 state (prefetched because it has a Computed Var). + child3 = await child.get_state(Child3) + child3.child3_var = 1 + + assert await child.v == 2 + assert await child.v == 2 + root.parent_var = 2 + assert await child.v == 3 + diff --git a/tests/units/test_var.py b/tests/units/test_var.py index a8e9cd88c..e34fb8871 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -15,6 +15,7 @@ from reflex.utils.exceptions import PrimitiveUnserializableToJSON from reflex.utils.imports import ImportVar from reflex.vars import VarData from reflex.vars.base import ( + AsyncComputedVar, ComputedVar, LiteralVar, Var, @@ -1808,9 +1809,9 @@ def cv_fget(state: BaseState) -> int: @pytest.mark.parametrize( "deps,expected", [ - (["a"], {"a"}), - (["b"], {"b"}), - ([ComputedVar(fget=cv_fget)], {"cv_fget"}), + (["a"], {None: {"a"}}), + (["b"], {None: {"b"}}), + ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}), ], ) def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]): @@ -1856,3 +1857,25 @@ def test_to_string_operation(): single_var = Var.create(Email()) assert single_var._var_type == Email + + +@pytest.mark.asyncio +async def test_async_computed_var(): + side_effect_counter = 0 + + class AsyncComputedVarState(BaseState): + v: int = 1 + + @computed_var(cache=True) + async def async_computed_var(self) -> int: + nonlocal side_effect_counter + side_effect_counter += 1 + return self.v + 1 + + my_state = AsyncComputedVarState() + assert await my_state.async_computed_var == 2 + assert await my_state.async_computed_var == 2 + my_state.v = 2 + assert await my_state.async_computed_var == 3 + assert await my_state.async_computed_var == 3 + assert side_effect_counter == 2