This commit is contained in:
Masen Furer 2025-01-22 05:00:49 -08:00
parent 048416163d
commit 3d50c1b623
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
8 changed files with 654 additions and 273 deletions

View File

@ -833,11 +833,17 @@ class App(MiddlewareMixin, LifespanMixin):
if not var._cache:
continue
deps = var._deps(objclass=state)
for dep in deps:
if dep not in state.vars and dep not in state.backend_vars:
raise exceptions.VarDependencyError(
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
)
for state_name, dep_set in deps.items():
state_cls = (
state.get_root_state().get_class_substate(state_name)
if state_name != state.get_full_name()
else state
)
for dep in dep_set:
if dep not in state_cls.vars and dep not in state_cls.backend_vars:
raise exceptions.VarDependencyError(
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
)
for substate in state.class_subclasses:
self._validate_var_dependencies(substate)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union
from urllib.parse import urlparse
@ -29,7 +30,7 @@ from reflex.components.base import (
)
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
from reflex.state import BaseState
from reflex.state import BaseState, _resolve_delta
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict
@ -169,7 +170,7 @@ def compile_state(state: Type[BaseState]) -> dict:
initial_state = state(_reflex_internal_init=True).dict(
initial=True, include_computed=False
)
return initial_state
return asyncio.run(_resolve_delta(initial_state))
def _compile_client_storage_field(

View File

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
from reflex import constants
from reflex.event import Event, get_hydrate_event
from reflex.middleware.middleware import Middleware
from reflex.state import BaseState, StateUpdate
from reflex.state import BaseState, StateUpdate, _resolve_delta
if TYPE_CHECKING:
from reflex.app import App
@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
setattr(state, constants.CompileVars.IS_HYDRATED, False)
# Get the initial state.
delta = state.dict()
delta = await _resolve_delta(state.dict())
# since a full dict was captured, clean any dirtiness
state._clean()

View File

@ -328,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
)
async def _resolve_delta(delta: Delta) -> Delta:
"""Await all coroutines in the delta.
Args:
delta: The delta to process.
Returns:
The same delta dict with all coroutines resolved to their return value.
"""
tasks = {}
for state_name, state_delta in delta.items():
for var_name, value in state_delta.items():
if asyncio.iscoroutine(value):
tasks[state_name, var_name] = asyncio.create_task(value)
for (state_name, var_name), task in tasks.items():
delta[state_name][var_name] = await task
return delta
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
@ -355,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# A set of subclassses of this class.
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
# Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# Mapping of var name to set of substates that depend on it
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# Mapping of var name to set of (state_full_name, var_name) that depend on it.
_var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
# Set of vars which always need to be recomputed
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
@ -367,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Set of substates which always need to be recomputed
_always_dirty_substates: ClassVar[Set[str]] = set()
# Set of states which might need to be recomputed if vars in this state change.
_potentially_dirty_states: ClassVar[Set[str]] = set()
# The parent state.
parent_state: Optional[BaseState] = None
@ -518,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Reset dirty substate tracking for this class.
cls._always_dirty_substates = set()
cls._potentially_dirty_states = set()
# Get the parent vars.
parent_state = cls.get_parent_state()
@ -621,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
setattr(cls, name, handler)
# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)
cls._var_dependencies = {}
cls._init_var_dependency_dicts()
@staticmethod
@ -767,26 +786,25 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
inherited_vars = set(cls.inherited_vars).union(
set(cls.inherited_backend_vars),
)
for cvar_name, cvar in cls.computed_vars.items():
# Add the dependencies.
for var in cvar._deps(objclass=cls):
cls._computed_var_dependencies[var].add(cvar_name)
if var in inherited_vars:
# track that this substate depends on its parent for this var
state_name = cls.get_name()
parent_state = cls.get_parent_state()
while parent_state is not None and var in {
**parent_state.vars,
**parent_state.backend_vars,
if not cvar._cache:
# Do not perform dep calculation when cache=False (these are always dirty).
continue
for state_name, dvar_set in cvar._deps(objclass=cls).items():
state_cls = cls.get_root_state().get_class_substate(state_name)
for dvar in dvar_set:
defining_state_cls = state_cls
while dvar in {
*defining_state_cls.inherited_vars,
*defining_state_cls.inherited_backend_vars,
}:
parent_state._substate_var_dependencies[var].add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.get_parent_state(),
)
defining_state_cls = defining_state_cls.get_parent_state()
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
(cls.get_full_name(), cvar_name)
)
defining_state_cls._potentially_dirty_states.add(
cls.get_full_name()
)
# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = {
@ -901,6 +919,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Only one parent state is allowed {parent_states}.")
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
@classmethod
@functools.lru_cache()
def get_root_state(cls) -> Type[BaseState]:
"""Get the root state.
Returns:
The root state.
"""
parent_state = cls.get_parent_state()
return cls if parent_state is None else parent_state.get_root_state()
@classmethod
def get_substates(cls) -> set[Type[BaseState]]:
"""Get the substates of the state.
@ -1353,7 +1382,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
super().__setattr__(name, value)
# Add the var to the dirty list.
if name in self.vars or name in self._computed_var_dependencies:
if name in self.base_vars:
self.dirty_vars.add(name)
self._mark_dirty()
@ -1423,6 +1452,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Invalid path: {path}")
return self.substates[path[0]].get_substate(path[1:])
@classmethod
def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
"""Get substates which may have dirty vars due to dependencies.
Returns:
The set of potentially dirty substate classes.
"""
return {
cls.get_class_substate(substate_name)
for substate_name in cls._always_dirty_substates
}.union(
{
cls.get_root_state().get_class_substate(substate_name)
for substate_name in cls._potentially_dirty_states
}
)
@classmethod
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
"""Find the name of the nearest common ancestor shared by this and the other state.
@ -1493,55 +1539,37 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
parent_state = parent_state.parent_state
return parent_state
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis.
Args:
target_state_cls: The class of the state to populate parent states for.
state_cls: The class of the state.
Returns:
The parent state instance of target_state_cls.
The instance of state_cls associated with this state's client_token.
Raises:
RuntimeError: If redis is not used in this backend process.
StateMismatchError: If the state instance is not of the expected type.
"""
# Then get the target state and all its substates.
state_manager = get_state_manager()
if not isinstance(state_manager, StateManagerRedis):
raise RuntimeError(
f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
state_in_redis = await state_manager._link_arbitrary_state(
self,
state_cls,
)
# Find the missing parent states up to the common ancestor.
(
common_ancestor_name,
missing_parent_states,
) = self._determine_missing_parent_states(target_state_cls)
# Fetch all missing parent states and link them up to the common ancestor.
parent_states_tuple = self._get_parent_states()
root_state = parent_states_tuple[-1][1]
parent_states_by_name = dict(parent_states_tuple)
parent_state = parent_states_by_name[common_ancestor_name]
for parent_state_name in missing_parent_states:
try:
parent_state = root_state.get_substate(parent_state_name.split("."))
# The requested state is already cached, do NOT fetch it again.
continue
except ValueError:
# The requested state is missing, fetch from redis.
pass
parent_state = await state_manager.get_state(
token=_substate_key(
self.router.session.client_token, parent_state_name
),
top_level=False,
get_substates=False,
parent_state=parent_state,
if not isinstance(state_in_redis, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state
return state_in_redis
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache.
@ -1563,44 +1591,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
return substate
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis.
Args:
state_cls: The class of the state.
Returns:
The instance of state_cls associated with this state's client_token.
Raises:
RuntimeError: If redis is not used in this backend process.
StateMismatchError: If the state instance is not of the expected type.
"""
# Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
# Then get the target state and all its substates.
state_manager = get_state_manager()
if not isinstance(state_manager, StateManagerRedis):
raise RuntimeError(
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
state_in_redis = await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)
if not isinstance(state_in_redis, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)
return state_in_redis
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token.
@ -1737,7 +1727,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
def _as_state_update(
async def _as_state_update(
self,
handler: EventHandler,
events: EventSpec | list[EventSpec] | None,
@ -1765,7 +1755,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
try:
# Get the delta after processing the event.
delta = state.get_delta()
delta = await _resolve_delta(state.get_delta())
state._clean()
return StateUpdate(
@ -1865,24 +1855,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
yield state._as_state_update(handler, event, final=False)
yield state._as_state_update(handler, events=None, final=True)
yield await state._as_state_update(handler, event, final=False)
yield await state._as_state_update(handler, events=None, final=True)
# Handle regular generators.
elif inspect.isgenerator(events):
try:
while True:
yield state._as_state_update(handler, next(events), final=False)
yield await state._as_state_update(
handler, next(events), final=False
)
except StopIteration as si:
# the "return" value of the generator is not available
# in the loop, we must catch StopIteration to access it
if si.value is not None:
yield state._as_state_update(handler, si.value, final=False)
yield state._as_state_update(handler, events=None, final=True)
yield await state._as_state_update(
handler, si.value, final=False
)
yield await state._as_state_update(handler, events=None, final=True)
# Handle regular event chains.
else:
yield state._as_state_update(handler, events, final=True)
yield await state._as_state_update(handler, events, final=True)
# If an error occurs, throw a window alert.
except Exception as ex:
@ -1892,7 +1886,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
)
yield state._as_state_update(
yield await state._as_state_update(
handler,
event_specs,
final=True,
@ -1900,15 +1894,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())
# Append always dirty computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._always_dirty_computed_vars)
dirty_vars = self.dirty_vars
while dirty_vars:
calc_vars, dirty_vars = dirty_vars, set()
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
self.dirty_vars.add(cvar)
for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
if state_name == self.get_full_name():
defining_state = self
else:
defining_state = self._get_root_state().get_substate(
tuple(state_name.split("."))
)
defining_state.dirty_vars.add(cvar)
dirty_vars.add(cvar)
actual_var = self.computed_vars.get(cvar)
actual_var = defining_state.computed_vars.get(cvar)
if actual_var is not None:
actual_var.mark_dirty(instance=self)
actual_var.mark_dirty(instance=defining_state)
if defining_state is not self:
defining_state._mark_dirty()
def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
@ -1924,7 +1931,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
) -> set[str]:
) -> set[tuple[str, str]]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Args:
@ -1935,32 +1942,59 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Set of computed vars to include in the delta.
"""
return {
cvar
(state_name, cvar)
for dirty_var in from_vars or self.dirty_vars
for cvar in self._computed_var_dependencies[dirty_var]
for state_name, cvar in self._var_dependencies.get(dirty_var, set())
if include_backend or not self.computed_vars[cvar]._backend
}
@classmethod
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
"""Determine substates which could be affected by dirty vars in this state.
async def _recursively_populate_dependent_substates(
self,
seen_classes: set[type[BaseState]] | None = None,
) -> set[type[BaseState]]:
"""Fetch all substates that have computed var dependencies on this state.
Args:
seen_classes: set of classes that have already been seen to prevent infinite recursion.
Returns:
Set of State classes that may need to be fetched to recalc computed vars.
The set of classes that were processed (mostly for testability).
"""
# _always_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = {
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
for substate_name in cls._always_dirty_substates
}
for dependent_substates in cls._substate_var_dependencies.values():
fetch_substates.update(
{
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
for substate_name in dependent_substates
}
if seen_classes is None:
print(
f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:"
)
return fetch_substates
seen_classes = set()
if type(self) in seen_classes:
return seen_classes
seen_classes.add(type(self))
populated_substate_instances = {}
for substate_cls in {
self.get_class_substate((self.get_name(), *substate_name.split(".")))
for substate_name in self._always_dirty_substates
}:
# _always_dirty_substates need to be fetched to recalc computed vars.
if substate_cls not in populated_substate_instances:
print(f"fetching always dirty {substate_cls}")
populated_substate_instances[substate_cls] = await self.get_state(
substate_cls
)
for dep_set in self._var_dependencies.values():
for substate_name, _ in dep_set:
if substate_name == self.get_full_name():
# Do NOT fetch our own state instance.
continue
substate_cls = self.get_root_state().get_class_substate(substate_name)
if substate_cls not in populated_substate_instances:
print(f"fetching dependent {substate_cls}")
populated_substate_instances[substate_cls] = await self.get_state(
substate_cls
)
for substate in populated_substate_instances.values():
await substate._recursively_populate_dependent_substates(
seen_classes=seen_classes,
)
return seen_classes
def get_delta(self) -> Delta:
"""Get the delta for the state.
@ -1970,21 +2004,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
delta = {}
# Apply dirty variables down into substates
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()
self._mark_dirty_computed_vars()
frontend_computed_vars: set[str] = {
name for name, cv in self.computed_vars.items() if not cv._backend
}
# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(frontend_computed_vars))
.union(self._dirty_computed_vars(include_backend=False))
.union(self._always_dirty_computed_vars)
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
self.dirty_vars.intersection(frontend_computed_vars)
)
subdelta: Dict[str, Any] = {
@ -2014,23 +2042,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty()
# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())
# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
self._mark_dirty_substates()
def _mark_dirty_substates(self):
"""Propagate dirty var / computed var status into substates."""
substates = self.substates
for var in self.dirty_vars:
for substate_name in self._substate_var_dependencies[var]:
self.dirty_substates.add(substate_name)
substate = substates[substate_name]
substate.dirty_vars.add(var)
substate._mark_dirty()
def _update_was_touched(self):
"""Update the _was_touched flag based on dirty_vars."""
@ -2102,11 +2116,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
The object as a dictionary.
"""
if include_computed:
# Apply dirty variables down into substates to allow never-cached ComputedVar to
# trigger recalculation of dependent vars
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()
self._mark_dirty_computed_vars()
base_vars = {
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
}
@ -3339,6 +3349,79 @@ class StateManagerRedis(StateManager):
)
return parent_state
async def _populate_parent_states(
self, calling_state: BaseState, target_state_cls: Type[BaseState]
):
"""Populate substates in the tree between the target_state_cls and common ancestor of calling_state.
Args:
calling_state: The substate instance requesting subtree population.
target_state_cls: The class of the state to populate parent states for.
Returns:
The parent state instance of target_state_cls.
"""
# Find the missing parent states up to the common ancestor.
(
common_ancestor_name,
missing_parent_states,
) = calling_state._determine_missing_parent_states(target_state_cls)
# Fetch all missing parent states and link them up to the common ancestor.
parent_states_tuple = calling_state._get_parent_states()
root_state = parent_states_tuple[-1][1]
parent_states_by_name = dict(parent_states_tuple)
parent_state = parent_states_by_name[common_ancestor_name]
for parent_state_name in missing_parent_states:
try:
parent_state = root_state.get_substate(parent_state_name.split("."))
# The requested state is already cached, do NOT fetch it again.
continue
except ValueError:
# The requested state is missing, fetch from redis.
pass
parent_state = await self.get_state(
token=_substate_key(
calling_state.router.session.client_token, parent_state_name
),
top_level=False,
get_substates=False,
parent_state=parent_state,
)
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state
async def _link_arbitrary_state(
self, calling_state: BaseState, state_cls: Type[T_STATE]
) -> T_STATE:
"""Get a state instance from redis.
Args:
calling_state: The state instance requesting the newly linked instance of state_cls.
state_cls: The class of the state to link into the tree.
Returns:
The instance of state_cls associated with calling_state's client_token.
Raises:
StateMismatchError: If the state instance is not of the expected type.
"""
# Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(
calling_state, state_cls
)
# Then get the target state and all its substates.
state_in_redis = await self.get_state(
token=_substate_key(calling_state.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)
return state_in_redis
async def _populate_substates(
self,
token: str,
@ -3357,30 +3440,40 @@ class StateManagerRedis(StateManager):
"""
client_token, _ = _split_substate_key(token)
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = state._get_potentially_dirty_states()
if all_substates:
# All substates are requested.
fetch_substates = state.get_substates()
else:
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = state._potentially_dirty_substates()
fetch_substates.update(state.get_substates())
tasks = {}
link_tasks = set()
# Retrieve the necessary substates from redis.
for substate_cls in fetch_substates:
if substate_cls.get_name() in state.substates:
continue
substate_name = substate_cls.get_name()
tasks[substate_name] = asyncio.create_task(
self.get_state(
token=_substate_key(client_token, substate_cls),
top_level=False,
get_substates=all_substates,
parent_state=state,
if substate_cls in state.get_substates():
tasks[substate_name] = asyncio.create_task(
self.get_state(
token=_substate_key(client_token, substate_cls),
top_level=False,
get_substates=all_substates,
parent_state=state,
)
)
)
else:
try:
state._get_root_state().get_substate(substate_name.split("."))
except ValueError:
# The requested state is missing, so fetch and link it (and its parents).
link_tasks.add(
asyncio.create_task(self._link_arbitrary_state(state, substate_cls))
)
for substate_name, substate_task in tasks.items():
state.substates[substate_name] = await substate_task
await asyncio.gather(*link_tasks)
@override
async def get_state(
@ -4153,7 +4246,7 @@ def reload_state_module(
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._potentially_dirty_substates.discard(subclass.get_name())
state._var_dependencies = {}
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()

View File

@ -1826,7 +1826,7 @@ class ComputedVar(Var[RETURN_TYPE]):
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
# Explicit var dependencies to track
_static_deps: set[str] = dataclasses.field(default_factory=set)
_static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict)
# Whether var dependencies should be auto-determined
_auto_deps: bool = dataclasses.field(default=True)
@ -1901,21 +1901,40 @@ class ComputedVar(Var[RETURN_TYPE]):
object.__setattr__(self, "_update_interval", interval)
if deps is None:
deps = []
else:
_static_deps = {}
if isinstance(deps, dict):
# Assume a dict is coming from _replace, so no special processing.
_static_deps = deps
elif deps is not None:
for dep in deps:
if isinstance(dep, Var):
continue
if isinstance(dep, str) and dep != "":
continue
raise TypeError(
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
)
state_name = (
all_var_data.state
if (all_var_data := dep._get_all_var_data())
and all_var_data.state
else None
)
var_name = (
dep._js_expr[len(formatted_state_prefix) :]
if state_name
and (
formatted_state_prefix := format_state_name(state_name)
+ "."
)
and dep._js_expr.startswith(formatted_state_prefix)
else dep._js_expr
)
_static_deps.setdefault(state_name, set()).add(var_name)
elif isinstance(dep, str) and dep != "":
_static_deps.setdefault(None, set()).add(dep)
else:
raise TypeError(
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
)
object.__setattr__(
self,
"_static_deps",
{dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
_static_deps,
)
object.__setattr__(self, "_auto_deps", auto_deps)
@ -2081,6 +2100,11 @@ class ComputedVar(Var[RETURN_TYPE]):
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
self._check_deprecated_return_type(instance, value)
return value
def _check_deprecated_return_type(self, instance, value) -> None:
if not _isinstance(value, self._var_type):
console.deprecate(
"mismatched-computed-var-return",
@ -2090,41 +2114,49 @@ class ComputedVar(Var[RETURN_TYPE]):
"0.7.0",
)
return value
def _deps(
self,
objclass: Type,
objclass: BaseState,
obj: FunctionType | CodeType | None = None,
self_name: Optional[str] = None,
) -> set[str]:
self_names: Optional[dict[str, str]] = None,
) -> dict[str, set[str]]:
"""Determine var dependencies of this ComputedVar.
Save references to attributes accessed on "self". Recursively called
when the function makes a method call on "self" or define comprehensions
or nested functions that may reference "self".
Save references to attributes accessed on "self" or other fetched states.
Recursively called when the function makes a method call on "self" or
define comprehensions or nested functions that may reference "self".
Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function).
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions.
Returns:
A set of variable names accessed by the given obj.
A dictionary mapping state names to the set of variable names
accessed by the given obj.
Raises:
VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
"""
from reflex.state import BaseState
d = {}
if self._static_deps:
d.update(self._static_deps)
# None is a placeholder for the current state class.
if None in d:
d[objclass.get_full_name()] = d.pop(None)
if not self._auto_deps:
return self._static_deps
d = self._static_deps.copy()
return d
if obj is None:
fget = self._fget
if fget is not None:
obj = cast(FunctionType, fget)
else:
return set()
return d
with contextlib.suppress(AttributeError):
# unbox functools.partial
obj = cast(FunctionType, obj.func) # type: ignore
@ -2132,76 +2164,150 @@ class ComputedVar(Var[RETURN_TYPE]):
# unbox EventHandler
obj = cast(FunctionType, obj.fn) # type: ignore
if self_name is None and isinstance(obj, FunctionType):
if self_names is None and isinstance(obj, FunctionType):
try:
# the first argument to the function is the name of "self" arg
self_name = obj.__code__.co_varnames[0]
self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()}
except (AttributeError, IndexError):
self_name = None
if self_name is None:
self_names = None
if self_names is None:
# cannot reference attributes on self if method takes no args
return set()
return d
invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
self_is_top_of_stack = False
invalid_names = ["parent_state", "substates", "get_substate"]
self_on_top_of_stack = None
getting_state = False
getting_var = False
for instruction in dis.get_instructions(obj):
if getting_state:
if instruction.opname == "LOAD_FAST":
raise VarValueError(
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
)
if instruction.opname == "LOAD_GLOBAL":
# Special case: referencing state class from global scope.
getting_state = obj.__globals__.get(instruction.argval)
elif instruction.opname == "LOAD_DEREF":
# Special case: referencing state class from closure.
closure = dict(zip(obj.__code__.co_freevars, obj.__closure__))
try:
getting_state = closure[instruction.argval].cell_contents
except ValueError as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
) from ve
elif instruction.opname == "STORE_FAST":
# Storing the result of get_state in a local variable.
if not isinstance(getting_state, type) or not issubclass(
getting_state, BaseState
):
raise VarValueError(
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
)
self_names[instruction.argval] = getting_state.get_full_name()
getting_state = False
continue # nothing else happens until we have identified the local var
if getting_var:
if instruction.opname == "CALL":
# get the original source code and eval it
start_line = getting_var[0].positions.lineno
start_column = getting_var[0].positions.col_offset
end_line = getting_var[-1].positions.end_lineno
end_column = getting_var[-1].positions.end_col_offset
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line]
if len(source) > 1:
snipped_source = "".join(
[
source[0][start_column:],
source[1:-2] if len(source) > 2 else "",
source[-1][:end_column]
]
)
else:
snipped_source = source[0][start_column:end_column]
the_var = eval(f"({snipped_source})", obj.__globals__)
print(the_var)
# code = source[start_line - 1]
# bytecode = bytearray((dis.opmap["RESUME"], 0))
# for ins in getting_var:
# bytecode.append(ins.opcode)
# bytecode.append(ins.arg or 0 & 0xFF)
# bytecode.extend((dis.opmap["RETURN_VALUE"], 0))
# bc = dis.Bytecode(obj)
# code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=())
# breakpoint()
getting_var = False
elif isinstance(getting_var, list):
getting_var.append(instruction)
else:
getting_var = [instruction]
continue
if (
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
and instruction.argval == self_name
and instruction.argval in self_names
):
# bytecode loaded the class instance to the top of stack, next load instruction
# is referencing an attribute on self
self_is_top_of_stack = True
self_on_top_of_stack = self_names[instruction.argval]
continue
if self_is_top_of_stack and instruction.opname in (
if self_on_top_of_stack and instruction.opname in (
"LOAD_ATTR",
"LOAD_METHOD",
):
try:
ref_obj = getattr(objclass, instruction.argval)
except Exception:
ref_obj = None
if instruction.argval in invalid_names:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
)
if instruction.argval == "get_state":
# Special case: arbitrary state access requested.
getting_state = True
continue
if instruction.argval == "get_var_value":
# Special case: arbitrary var access requested.
getting_var = True
continue
print(f"{self_on_top_of_stack=}")
target_state = objclass.get_root_state().get_class_substate(
self_on_top_of_stack
)
try:
ref_obj = getattr(target_state, instruction.argval)
except Exception:
ref_obj = None
if callable(ref_obj):
# recurse into callable attributes
d.update(
self._deps(
objclass=objclass,
obj=ref_obj,
)
)
for state_name, dep_name in self._deps(
objclass=target_state,
obj=ref_obj,
).items():
d.setdefault(state_name, set()).update(dep_name)
# recurse into property fget functions
elif isinstance(ref_obj, property) and not isinstance(
ref_obj, ComputedVar
):
d.update(
self._deps(
objclass=objclass,
obj=ref_obj.fget, # type: ignore
)
)
for state_name, dep_name in self._deps(
objclass=target_state,
obj=ref_obj.fget, # type: ignore
).items():
d.setdefault(state_name, set()).update(dep_name)
elif (
instruction.argval in objclass.backend_vars
or instruction.argval in objclass.vars
instruction.argval in target_state.backend_vars
or instruction.argval in target_state.vars
):
# var access
d.add(instruction.argval)
d.setdefault(self_on_top_of_stack, set()).add(instruction.argval)
elif instruction.opname == "LOAD_CONST" and isinstance(
instruction.argval, CodeType
):
# recurse into nested functions / comprehensions, which can reference
# instance attributes from the outer scope
d.update(
self._deps(
objclass=objclass,
obj=instruction.argval,
self_name=self_name,
)
)
self_is_top_of_stack = False
for state_name, dep_name in self._deps(
objclass=objclass,
obj=instruction.argval,
self_names=self_names,
).items():
d.setdefault(state_name, set()).update(dep_name)
self_on_top_of_stack = None
return d
def mark_dirty(self, instance) -> None:
@ -2249,6 +2355,60 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
pass
@dataclasses.dataclass(
eq=False,
frozen=True,
init=False,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
"""A computed var that wraps a coroutinefunction."""
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
default_factory=lambda: lambda _: None
) # type: ignore
def __get__(self, instance: BaseState | None, owner):
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
Args:
instance: the instance of the class accessing this computed var.
owner: the class that this descriptor is attached to.
Returns:
The value of the var for the given instance.
"""
if instance is None:
return super(AsyncComputedVar, self).__get__(instance, owner)
if not self._cache:
async def _awaitable_result():
value = await self.fget(instance)
self._check_deprecated_return_type(instance, value)
return _awaitable_result()
else:
# handle caching
async def _awaitable_result():
if not hasattr(instance, self._cache_attr) or self.needs_update(
instance
):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, await self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
self._check_deprecated_return_type(instance, value)
return value
return _awaitable_result()
if TYPE_CHECKING:
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
@ -2315,10 +2475,27 @@ def computed_var(
raise VarDependencyError("Cannot track dependencies without caching.")
if fget is not None:
return ComputedVar(fget, cache=cache)
if inspect.iscoroutinefunction(fget):
computed_var_cls = AsyncComputedVar
else:
computed_var_cls = ComputedVar
return computed_var_cls(
fget,
initial_value=initial_value,
cache=cache,
deps=deps,
auto_deps=auto_deps,
interval=interval,
backend=backend,
**kwargs,
)
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
return ComputedVar(
if inspect.iscoroutinefunction(fget):
computed_var_cls = AsyncComputedVar
else:
computed_var_cls = ComputedVar
return computed_var_cls(
fget,
initial_value=initial_value,
cache=cache,

View File

@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
assert app.pages.keys() == {"test/[dynamic]"}
assert "dynamic" in app.state.computed_vars
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
constants.ROUTER
EmptyState.get_full_name(): {constants.ROUTER},
}
assert constants.ROUTER in app.state()._computed_var_dependencies
assert constants.ROUTER in app.state()._var_dependencies
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@ -997,9 +997,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
assert arg_name in app.state.vars
assert arg_name in app.state.computed_vars
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
constants.ROUTER
DynamicState.get_full_name(): {constants.ROUTER},
}
assert constants.ROUTER in app.state()._computed_var_dependencies
assert constants.ROUTER in app.state()._var_dependencies
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
@ -1557,6 +1557,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
def bar(self) -> str:
return "bar"
class Child1(ValidDepState):
@computed_var(deps=["base", ValidDepState.bar])
def other(self) -> str:
return "other"
class Child2(ValidDepState):
@computed_var(deps=["base", Child1.other])
def other(self) -> str:
return "other"
app.state = ValidDepState
app._compile()

View File

@ -1170,13 +1170,11 @@ def test_conditional_computed_vars():
ms = MainState()
# Initially there are no dirty computed vars.
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")}
assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")}
assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")}
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
"flag",
"t1",
"t2",
MainState.get_full_name(): {"flag", "t1", "t2"}
}
@ -1371,7 +1369,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
@ -1461,15 +1459,15 @@ def test_computed_var_dependencies():
return [z in self._z for z in range(5)]
cs = ComputedState()
assert cs._computed_var_dependencies["v"] == {
"comp_v",
"comp_v_backend",
"comp_v_via_property",
assert cs._var_dependencies["v"] == {
(ComputedState.get_full_name(), "comp_v"),
(ComputedState.get_full_name(), "comp_v_backend"),
(ComputedState.get_full_name(), "comp_v_via_property"),
}
assert cs._computed_var_dependencies["w"] == {"comp_w"}
assert cs._computed_var_dependencies["x"] == {"comp_x"}
assert cs._computed_var_dependencies["y"] == {"comp_y"}
assert cs._computed_var_dependencies["_z"] == {"comp_z"}
assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
def test_backend_method():
@ -3182,6 +3180,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
RxState = State
@pytest.mark.skip(reason="This test is maybe not relevant anymore.")
def test_potentially_dirty_substates():
"""Test that potentially_dirty_substates returns the correct substates.
@ -3203,7 +3202,8 @@ def test_potentially_dirty_substates():
assert C1._potentially_dirty_substates() == set()
def test_router_var_dep() -> None:
@pytest.mark.asyncio
async def test_router_var_dep() -> None:
"""Test that router var dependencies are correctly tracked."""
class RouterVarParentState(State):
@ -3221,13 +3221,9 @@ def test_router_var_dep() -> None:
foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()
assert foo._deps(objclass=RouterVarDepState) == {"router"}
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
assert RouterVarParentState._substate_var_dependencies == {
"router": {RouterVarDepState.get_name()}
}
assert RouterVarDepState._computed_var_dependencies == {
"router": {"foo"},
assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}}
assert State._var_dependencies == {
"router": {(RouterVarDepState.get_full_name(), "foo")}
}
rx_state = State()
@ -3240,11 +3236,15 @@ def test_router_var_dep() -> None:
state.parent_state = parent_state
parent_state.substates = {RouterVarDepState.get_name(): state}
populated_substate_classes = await rx_state._recursively_populate_dependent_substates()
assert populated_substate_classes == {State, RouterVarDepState}
assert state.dirty_vars == set()
# Reassign router var
state.router = state.router
assert state.dirty_vars == {"foo", "router"}
assert rx_state.dirty_vars == {"router"}
assert state.dirty_vars == {"foo"}
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
@ -3803,3 +3803,74 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
# Generic Var with no state
with pytest.raises(UnretrievableVarValueError):
await state.get_var_value(rx.Var("undefined"))
@pytest.mark.asyncio
async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
"""A test where an async computed var depends on a var in another state.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
class Parent(BaseState):
"""A root state like rx.State."""
parent_var: int = 0
class Child2(Parent):
"""An unconnected child state."""
pass
class Child3(Parent):
"""A child state with a computed var causing it to be pre-fetched.
If child3_var gets set to a value, and `get_state` erroneously
re-fetches it from redis, the value will be lost.
"""
child3_var: int = 0
@rx.var(cache=True)
def v(self):
return self.child3_var
class Child(Parent):
"""A state simulating UpdateVarsInternalState."""
@rx.var(cache=True)
async def v(self):
p = await self.get_state(Parent)
child3 = await self.get_state(Child3)
return child3.child3_var + p.parent_var
mock_app.state_manager.state = mock_app.state = Parent
# Get the top level state via unconnected sibling.
root = await mock_app.state_manager.get_state(_substate_key(token, Child))
# Set value in parent_var to assert it does not get refetched later.
root.parent_var = 1
if isinstance(mock_app.state_manager, StateManagerRedis):
# When redis is used, only states with uncached computed vars are pre-fetched.
assert Child2.get_name() not in root.substates
assert Child3.get_name() not in root.substates
# Get the unconnected sibling state, which will be used to `get_state` other instances.
child = root.get_substate(Child.get_full_name().split("."))
# Get an uncached child state.
child2 = await child.get_state(Child2)
assert child2.parent_var == 1
# Set value on already-cached Child3 state (prefetched because it has a Computed Var).
child3 = await child.get_state(Child3)
child3.child3_var = 1
assert await child.v == 2
assert await child.v == 2
root.parent_var = 2
assert await child.v == 3

View File

@ -15,6 +15,7 @@ from reflex.utils.exceptions import PrimitiveUnserializableToJSON
from reflex.utils.imports import ImportVar
from reflex.vars import VarData
from reflex.vars.base import (
AsyncComputedVar,
ComputedVar,
LiteralVar,
Var,
@ -1808,9 +1809,9 @@ def cv_fget(state: BaseState) -> int:
@pytest.mark.parametrize(
"deps,expected",
[
(["a"], {"a"}),
(["b"], {"b"}),
([ComputedVar(fget=cv_fget)], {"cv_fget"}),
(["a"], {None: {"a"}}),
(["b"], {None: {"b"}}),
([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
],
)
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
@ -1856,3 +1857,25 @@ def test_to_string_operation():
single_var = Var.create(Email())
assert single_var._var_type == Email
@pytest.mark.asyncio
async def test_async_computed_var():
side_effect_counter = 0
class AsyncComputedVarState(BaseState):
v: int = 1
@computed_var(cache=True)
async def async_computed_var(self) -> int:
nonlocal side_effect_counter
side_effect_counter += 1
return self.v + 1
my_state = AsyncComputedVarState()
assert await my_state.async_computed_var == 2
assert await my_state.async_computed_var == 2
my_state.v = 2
assert await my_state.async_computed_var == 3
assert await my_state.async_computed_var == 3
assert side_effect_counter == 2