WiP
This commit is contained in:
parent
048416163d
commit
3d50c1b623
@ -833,11 +833,17 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
if not var._cache:
|
||||
continue
|
||||
deps = var._deps(objclass=state)
|
||||
for dep in deps:
|
||||
if dep not in state.vars and dep not in state.backend_vars:
|
||||
raise exceptions.VarDependencyError(
|
||||
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
|
||||
)
|
||||
for state_name, dep_set in deps.items():
|
||||
state_cls = (
|
||||
state.get_root_state().get_class_substate(state_name)
|
||||
if state_name != state.get_full_name()
|
||||
else state
|
||||
)
|
||||
for dep in dep_set:
|
||||
if dep not in state_cls.vars and dep not in state_cls.backend_vars:
|
||||
raise exceptions.VarDependencyError(
|
||||
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
|
||||
)
|
||||
|
||||
for substate in state.class_subclasses:
|
||||
self._validate_var_dependencies(substate)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||
from urllib.parse import urlparse
|
||||
@ -29,7 +30,7 @@ from reflex.components.base import (
|
||||
)
|
||||
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
||||
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
|
||||
from reflex.state import BaseState
|
||||
from reflex.state import BaseState, _resolve_delta
|
||||
from reflex.style import Style
|
||||
from reflex.utils import console, format, imports, path_ops
|
||||
from reflex.utils.imports import ImportVar, ParsedImportDict
|
||||
@ -169,7 +170,7 @@ def compile_state(state: Type[BaseState]) -> dict:
|
||||
initial_state = state(_reflex_internal_init=True).dict(
|
||||
initial=True, include_computed=False
|
||||
)
|
||||
return initial_state
|
||||
return asyncio.run(_resolve_delta(initial_state))
|
||||
|
||||
|
||||
def _compile_client_storage_field(
|
||||
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from reflex import constants
|
||||
from reflex.event import Event, get_hydrate_event
|
||||
from reflex.middleware.middleware import Middleware
|
||||
from reflex.state import BaseState, StateUpdate
|
||||
from reflex.state import BaseState, StateUpdate, _resolve_delta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.app import App
|
||||
@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
|
||||
setattr(state, constants.CompileVars.IS_HYDRATED, False)
|
||||
|
||||
# Get the initial state.
|
||||
delta = state.dict()
|
||||
delta = await _resolve_delta(state.dict())
|
||||
# since a full dict was captured, clean any dirtiness
|
||||
state._clean()
|
||||
|
||||
|
433
reflex/state.py
433
reflex/state.py
@ -328,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_delta(delta: Delta) -> Delta:
|
||||
"""Await all coroutines in the delta.
|
||||
|
||||
Args:
|
||||
delta: The delta to process.
|
||||
|
||||
Returns:
|
||||
The same delta dict with all coroutines resolved to their return value.
|
||||
"""
|
||||
tasks = {}
|
||||
for state_name, state_delta in delta.items():
|
||||
for var_name, value in state_delta.items():
|
||||
if asyncio.iscoroutine(value):
|
||||
tasks[state_name, var_name] = asyncio.create_task(value)
|
||||
for (state_name, var_name), task in tasks.items():
|
||||
delta[state_name][var_name] = await task
|
||||
return delta
|
||||
|
||||
|
||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""The state of the app."""
|
||||
|
||||
@ -355,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# A set of subclassses of this class.
|
||||
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
||||
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
|
||||
# Mapping of var name to set of substates that depend on it
|
||||
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
# Mapping of var name to set of (state_full_name, var_name) that depend on it.
|
||||
_var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
|
||||
|
||||
# Set of vars which always need to be recomputed
|
||||
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
||||
@ -367,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Set of substates which always need to be recomputed
|
||||
_always_dirty_substates: ClassVar[Set[str]] = set()
|
||||
|
||||
# Set of states which might need to be recomputed if vars in this state change.
|
||||
_potentially_dirty_states: ClassVar[Set[str]] = set()
|
||||
|
||||
# The parent state.
|
||||
parent_state: Optional[BaseState] = None
|
||||
|
||||
@ -518,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
# Reset dirty substate tracking for this class.
|
||||
cls._always_dirty_substates = set()
|
||||
cls._potentially_dirty_states = set()
|
||||
|
||||
# Get the parent vars.
|
||||
parent_state = cls.get_parent_state()
|
||||
@ -621,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
setattr(cls, name, handler)
|
||||
|
||||
# Initialize per-class var dependency tracking.
|
||||
cls._computed_var_dependencies = defaultdict(set)
|
||||
cls._substate_var_dependencies = defaultdict(set)
|
||||
cls._var_dependencies = {}
|
||||
cls._init_var_dependency_dicts()
|
||||
|
||||
@staticmethod
|
||||
@ -767,26 +786,25 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Additional updates tracking dicts for vars and substates that always
|
||||
need to be recomputed.
|
||||
"""
|
||||
inherited_vars = set(cls.inherited_vars).union(
|
||||
set(cls.inherited_backend_vars),
|
||||
)
|
||||
for cvar_name, cvar in cls.computed_vars.items():
|
||||
# Add the dependencies.
|
||||
for var in cvar._deps(objclass=cls):
|
||||
cls._computed_var_dependencies[var].add(cvar_name)
|
||||
if var in inherited_vars:
|
||||
# track that this substate depends on its parent for this var
|
||||
state_name = cls.get_name()
|
||||
parent_state = cls.get_parent_state()
|
||||
while parent_state is not None and var in {
|
||||
**parent_state.vars,
|
||||
**parent_state.backend_vars,
|
||||
if not cvar._cache:
|
||||
# Do not perform dep calculation when cache=False (these are always dirty).
|
||||
continue
|
||||
for state_name, dvar_set in cvar._deps(objclass=cls).items():
|
||||
state_cls = cls.get_root_state().get_class_substate(state_name)
|
||||
for dvar in dvar_set:
|
||||
defining_state_cls = state_cls
|
||||
while dvar in {
|
||||
*defining_state_cls.inherited_vars,
|
||||
*defining_state_cls.inherited_backend_vars,
|
||||
}:
|
||||
parent_state._substate_var_dependencies[var].add(state_name)
|
||||
state_name, parent_state = (
|
||||
parent_state.get_name(),
|
||||
parent_state.get_parent_state(),
|
||||
)
|
||||
defining_state_cls = defining_state_cls.get_parent_state()
|
||||
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
||||
(cls.get_full_name(), cvar_name)
|
||||
)
|
||||
defining_state_cls._potentially_dirty_states.add(
|
||||
cls.get_full_name()
|
||||
)
|
||||
|
||||
# ComputedVar with cache=False always need to be recomputed
|
||||
cls._always_dirty_computed_vars = {
|
||||
@ -901,6 +919,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
raise ValueError(f"Only one parent state is allowed {parent_states}.")
|
||||
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def get_root_state(cls) -> Type[BaseState]:
|
||||
"""Get the root state.
|
||||
|
||||
Returns:
|
||||
The root state.
|
||||
"""
|
||||
parent_state = cls.get_parent_state()
|
||||
return cls if parent_state is None else parent_state.get_root_state()
|
||||
|
||||
@classmethod
|
||||
def get_substates(cls) -> set[Type[BaseState]]:
|
||||
"""Get the substates of the state.
|
||||
@ -1353,7 +1382,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
super().__setattr__(name, value)
|
||||
|
||||
# Add the var to the dirty list.
|
||||
if name in self.vars or name in self._computed_var_dependencies:
|
||||
if name in self.base_vars:
|
||||
self.dirty_vars.add(name)
|
||||
self._mark_dirty()
|
||||
|
||||
@ -1423,6 +1452,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
return self.substates[path[0]].get_substate(path[1:])
|
||||
|
||||
@classmethod
|
||||
def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
|
||||
"""Get substates which may have dirty vars due to dependencies.
|
||||
|
||||
Returns:
|
||||
The set of potentially dirty substate classes.
|
||||
"""
|
||||
return {
|
||||
cls.get_class_substate(substate_name)
|
||||
for substate_name in cls._always_dirty_substates
|
||||
}.union(
|
||||
{
|
||||
cls.get_root_state().get_class_substate(substate_name)
|
||||
for substate_name in cls._potentially_dirty_states
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
||||
"""Find the name of the nearest common ancestor shared by this and the other state.
|
||||
@ -1493,55 +1539,37 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
parent_state = parent_state.parent_state
|
||||
return parent_state
|
||||
|
||||
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
||||
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
||||
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||
"""Get a state instance from redis.
|
||||
|
||||
Args:
|
||||
target_state_cls: The class of the state to populate parent states for.
|
||||
state_cls: The class of the state.
|
||||
|
||||
Returns:
|
||||
The parent state instance of target_state_cls.
|
||||
The instance of state_cls associated with this state's client_token.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If redis is not used in this backend process.
|
||||
StateMismatchError: If the state instance is not of the expected type.
|
||||
"""
|
||||
# Then get the target state and all its substates.
|
||||
state_manager = get_state_manager()
|
||||
if not isinstance(state_manager, StateManagerRedis):
|
||||
raise RuntimeError(
|
||||
f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
|
||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||
"(All states should already be available -- this is likely a bug).",
|
||||
)
|
||||
state_in_redis = await state_manager._link_arbitrary_state(
|
||||
self,
|
||||
state_cls,
|
||||
)
|
||||
|
||||
# Find the missing parent states up to the common ancestor.
|
||||
(
|
||||
common_ancestor_name,
|
||||
missing_parent_states,
|
||||
) = self._determine_missing_parent_states(target_state_cls)
|
||||
|
||||
# Fetch all missing parent states and link them up to the common ancestor.
|
||||
parent_states_tuple = self._get_parent_states()
|
||||
root_state = parent_states_tuple[-1][1]
|
||||
parent_states_by_name = dict(parent_states_tuple)
|
||||
parent_state = parent_states_by_name[common_ancestor_name]
|
||||
for parent_state_name in missing_parent_states:
|
||||
try:
|
||||
parent_state = root_state.get_substate(parent_state_name.split("."))
|
||||
# The requested state is already cached, do NOT fetch it again.
|
||||
continue
|
||||
except ValueError:
|
||||
# The requested state is missing, fetch from redis.
|
||||
pass
|
||||
parent_state = await state_manager.get_state(
|
||||
token=_substate_key(
|
||||
self.router.session.client_token, parent_state_name
|
||||
),
|
||||
top_level=False,
|
||||
get_substates=False,
|
||||
parent_state=parent_state,
|
||||
if not isinstance(state_in_redis, state_cls):
|
||||
raise StateMismatchError(
|
||||
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
||||
)
|
||||
|
||||
# Return the direct parent of target_state_cls for subsequent linking.
|
||||
return parent_state
|
||||
return state_in_redis
|
||||
|
||||
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||
"""Get a state instance from the cache.
|
||||
@ -1563,44 +1591,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
)
|
||||
return substate
|
||||
|
||||
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||
"""Get a state instance from redis.
|
||||
|
||||
Args:
|
||||
state_cls: The class of the state.
|
||||
|
||||
Returns:
|
||||
The instance of state_cls associated with this state's client_token.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If redis is not used in this backend process.
|
||||
StateMismatchError: If the state instance is not of the expected type.
|
||||
"""
|
||||
# Fetch all missing parent states from redis.
|
||||
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
||||
|
||||
# Then get the target state and all its substates.
|
||||
state_manager = get_state_manager()
|
||||
if not isinstance(state_manager, StateManagerRedis):
|
||||
raise RuntimeError(
|
||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||
"(All states should already be available -- this is likely a bug).",
|
||||
)
|
||||
|
||||
state_in_redis = await state_manager.get_state(
|
||||
token=_substate_key(self.router.session.client_token, state_cls),
|
||||
top_level=False,
|
||||
get_substates=True,
|
||||
parent_state=parent_state_of_state_cls,
|
||||
)
|
||||
|
||||
if not isinstance(state_in_redis, state_cls):
|
||||
raise StateMismatchError(
|
||||
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
||||
)
|
||||
|
||||
return state_in_redis
|
||||
|
||||
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||
"""Get an instance of the state associated with this token.
|
||||
|
||||
@ -1737,7 +1727,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||
)
|
||||
|
||||
def _as_state_update(
|
||||
async def _as_state_update(
|
||||
self,
|
||||
handler: EventHandler,
|
||||
events: EventSpec | list[EventSpec] | None,
|
||||
@ -1765,7 +1755,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
try:
|
||||
# Get the delta after processing the event.
|
||||
delta = state.get_delta()
|
||||
delta = await _resolve_delta(state.get_delta())
|
||||
state._clean()
|
||||
|
||||
return StateUpdate(
|
||||
@ -1865,24 +1855,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Handle async generators.
|
||||
if inspect.isasyncgen(events):
|
||||
async for event in events:
|
||||
yield state._as_state_update(handler, event, final=False)
|
||||
yield state._as_state_update(handler, events=None, final=True)
|
||||
yield await state._as_state_update(handler, event, final=False)
|
||||
yield await state._as_state_update(handler, events=None, final=True)
|
||||
|
||||
# Handle regular generators.
|
||||
elif inspect.isgenerator(events):
|
||||
try:
|
||||
while True:
|
||||
yield state._as_state_update(handler, next(events), final=False)
|
||||
yield await state._as_state_update(
|
||||
handler, next(events), final=False
|
||||
)
|
||||
except StopIteration as si:
|
||||
# the "return" value of the generator is not available
|
||||
# in the loop, we must catch StopIteration to access it
|
||||
if si.value is not None:
|
||||
yield state._as_state_update(handler, si.value, final=False)
|
||||
yield state._as_state_update(handler, events=None, final=True)
|
||||
yield await state._as_state_update(
|
||||
handler, si.value, final=False
|
||||
)
|
||||
yield await state._as_state_update(handler, events=None, final=True)
|
||||
|
||||
# Handle regular event chains.
|
||||
else:
|
||||
yield state._as_state_update(handler, events, final=True)
|
||||
yield await state._as_state_update(handler, events, final=True)
|
||||
|
||||
# If an error occurs, throw a window alert.
|
||||
except Exception as ex:
|
||||
@ -1892,7 +1886,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
|
||||
)
|
||||
|
||||
yield state._as_state_update(
|
||||
yield await state._as_state_update(
|
||||
handler,
|
||||
event_specs,
|
||||
final=True,
|
||||
@ -1900,15 +1894,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
def _mark_dirty_computed_vars(self) -> None:
|
||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||
# Append expired computed vars to dirty_vars to trigger recalculation
|
||||
self.dirty_vars.update(self._expired_computed_vars())
|
||||
# Append always dirty computed vars to dirty_vars to trigger recalculation
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||
|
||||
dirty_vars = self.dirty_vars
|
||||
while dirty_vars:
|
||||
calc_vars, dirty_vars = dirty_vars, set()
|
||||
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||
self.dirty_vars.add(cvar)
|
||||
for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||
if state_name == self.get_full_name():
|
||||
defining_state = self
|
||||
else:
|
||||
defining_state = self._get_root_state().get_substate(
|
||||
tuple(state_name.split("."))
|
||||
)
|
||||
defining_state.dirty_vars.add(cvar)
|
||||
dirty_vars.add(cvar)
|
||||
actual_var = self.computed_vars.get(cvar)
|
||||
actual_var = defining_state.computed_vars.get(cvar)
|
||||
if actual_var is not None:
|
||||
actual_var.mark_dirty(instance=self)
|
||||
actual_var.mark_dirty(instance=defining_state)
|
||||
if defining_state is not self:
|
||||
defining_state._mark_dirty()
|
||||
|
||||
def _expired_computed_vars(self) -> set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
||||
@ -1924,7 +1931,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
def _dirty_computed_vars(
|
||||
self, from_vars: set[str] | None = None, include_backend: bool = True
|
||||
) -> set[str]:
|
||||
) -> set[tuple[str, str]]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||
|
||||
Args:
|
||||
@ -1935,32 +1942,59 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Set of computed vars to include in the delta.
|
||||
"""
|
||||
return {
|
||||
cvar
|
||||
(state_name, cvar)
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self._computed_var_dependencies[dirty_var]
|
||||
for state_name, cvar in self._var_dependencies.get(dirty_var, set())
|
||||
if include_backend or not self.computed_vars[cvar]._backend
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
||||
"""Determine substates which could be affected by dirty vars in this state.
|
||||
async def _recursively_populate_dependent_substates(
|
||||
self,
|
||||
seen_classes: set[type[BaseState]] | None = None,
|
||||
) -> set[type[BaseState]]:
|
||||
"""Fetch all substates that have computed var dependencies on this state.
|
||||
|
||||
Args:
|
||||
seen_classes: set of classes that have already been seen to prevent infinite recursion.
|
||||
|
||||
Returns:
|
||||
Set of State classes that may need to be fetched to recalc computed vars.
|
||||
The set of classes that were processed (mostly for testability).
|
||||
"""
|
||||
# _always_dirty_substates need to be fetched to recalc computed vars.
|
||||
fetch_substates = {
|
||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
||||
for substate_name in cls._always_dirty_substates
|
||||
}
|
||||
for dependent_substates in cls._substate_var_dependencies.values():
|
||||
fetch_substates.update(
|
||||
{
|
||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
||||
for substate_name in dependent_substates
|
||||
}
|
||||
if seen_classes is None:
|
||||
print(
|
||||
f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:"
|
||||
)
|
||||
return fetch_substates
|
||||
seen_classes = set()
|
||||
if type(self) in seen_classes:
|
||||
return seen_classes
|
||||
seen_classes.add(type(self))
|
||||
populated_substate_instances = {}
|
||||
for substate_cls in {
|
||||
self.get_class_substate((self.get_name(), *substate_name.split(".")))
|
||||
for substate_name in self._always_dirty_substates
|
||||
}:
|
||||
# _always_dirty_substates need to be fetched to recalc computed vars.
|
||||
if substate_cls not in populated_substate_instances:
|
||||
print(f"fetching always dirty {substate_cls}")
|
||||
populated_substate_instances[substate_cls] = await self.get_state(
|
||||
substate_cls
|
||||
)
|
||||
for dep_set in self._var_dependencies.values():
|
||||
for substate_name, _ in dep_set:
|
||||
if substate_name == self.get_full_name():
|
||||
# Do NOT fetch our own state instance.
|
||||
continue
|
||||
substate_cls = self.get_root_state().get_class_substate(substate_name)
|
||||
if substate_cls not in populated_substate_instances:
|
||||
print(f"fetching dependent {substate_cls}")
|
||||
populated_substate_instances[substate_cls] = await self.get_state(
|
||||
substate_cls
|
||||
)
|
||||
for substate in populated_substate_instances.values():
|
||||
await substate._recursively_populate_dependent_substates(
|
||||
seen_classes=seen_classes,
|
||||
)
|
||||
return seen_classes
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
@ -1970,21 +2004,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
delta = {}
|
||||
|
||||
# Apply dirty variables down into substates
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||
self._mark_dirty()
|
||||
|
||||
self._mark_dirty_computed_vars()
|
||||
frontend_computed_vars: set[str] = {
|
||||
name for name, cv in self.computed_vars.items() if not cv._backend
|
||||
}
|
||||
|
||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||
# and always dirty computed vars (cache=False)
|
||||
delta_vars = (
|
||||
self.dirty_vars.intersection(self.base_vars)
|
||||
.union(self.dirty_vars.intersection(frontend_computed_vars))
|
||||
.union(self._dirty_computed_vars(include_backend=False))
|
||||
.union(self._always_dirty_computed_vars)
|
||||
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||
self.dirty_vars.intersection(frontend_computed_vars)
|
||||
)
|
||||
|
||||
subdelta: Dict[str, Any] = {
|
||||
@ -2014,23 +2042,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self.parent_state.dirty_substates.add(self.get_name())
|
||||
self.parent_state._mark_dirty()
|
||||
|
||||
# Append expired computed vars to dirty_vars to trigger recalculation
|
||||
self.dirty_vars.update(self._expired_computed_vars())
|
||||
|
||||
# have to mark computed vars dirty to allow access to newly computed
|
||||
# values within the same ComputedVar function
|
||||
self._mark_dirty_computed_vars()
|
||||
self._mark_dirty_substates()
|
||||
|
||||
def _mark_dirty_substates(self):
|
||||
"""Propagate dirty var / computed var status into substates."""
|
||||
substates = self.substates
|
||||
for var in self.dirty_vars:
|
||||
for substate_name in self._substate_var_dependencies[var]:
|
||||
self.dirty_substates.add(substate_name)
|
||||
substate = substates[substate_name]
|
||||
substate.dirty_vars.add(var)
|
||||
substate._mark_dirty()
|
||||
|
||||
def _update_was_touched(self):
|
||||
"""Update the _was_touched flag based on dirty_vars."""
|
||||
@ -2102,11 +2116,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
The object as a dictionary.
|
||||
"""
|
||||
if include_computed:
|
||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
||||
# trigger recalculation of dependent vars
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||
self._mark_dirty()
|
||||
|
||||
self._mark_dirty_computed_vars()
|
||||
base_vars = {
|
||||
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
|
||||
}
|
||||
@ -3339,6 +3349,79 @@ class StateManagerRedis(StateManager):
|
||||
)
|
||||
return parent_state
|
||||
|
||||
async def _populate_parent_states(
|
||||
self, calling_state: BaseState, target_state_cls: Type[BaseState]
|
||||
):
|
||||
"""Populate substates in the tree between the target_state_cls and common ancestor of calling_state.
|
||||
|
||||
Args:
|
||||
calling_state: The substate instance requesting subtree population.
|
||||
target_state_cls: The class of the state to populate parent states for.
|
||||
|
||||
Returns:
|
||||
The parent state instance of target_state_cls.
|
||||
"""
|
||||
# Find the missing parent states up to the common ancestor.
|
||||
(
|
||||
common_ancestor_name,
|
||||
missing_parent_states,
|
||||
) = calling_state._determine_missing_parent_states(target_state_cls)
|
||||
|
||||
# Fetch all missing parent states and link them up to the common ancestor.
|
||||
parent_states_tuple = calling_state._get_parent_states()
|
||||
root_state = parent_states_tuple[-1][1]
|
||||
parent_states_by_name = dict(parent_states_tuple)
|
||||
parent_state = parent_states_by_name[common_ancestor_name]
|
||||
for parent_state_name in missing_parent_states:
|
||||
try:
|
||||
parent_state = root_state.get_substate(parent_state_name.split("."))
|
||||
# The requested state is already cached, do NOT fetch it again.
|
||||
continue
|
||||
except ValueError:
|
||||
# The requested state is missing, fetch from redis.
|
||||
pass
|
||||
parent_state = await self.get_state(
|
||||
token=_substate_key(
|
||||
calling_state.router.session.client_token, parent_state_name
|
||||
),
|
||||
top_level=False,
|
||||
get_substates=False,
|
||||
parent_state=parent_state,
|
||||
)
|
||||
|
||||
# Return the direct parent of target_state_cls for subsequent linking.
|
||||
return parent_state
|
||||
|
||||
async def _link_arbitrary_state(
|
||||
self, calling_state: BaseState, state_cls: Type[T_STATE]
|
||||
) -> T_STATE:
|
||||
"""Get a state instance from redis.
|
||||
|
||||
Args:
|
||||
calling_state: The state instance requesting the newly linked instance of state_cls.
|
||||
state_cls: The class of the state to link into the tree.
|
||||
|
||||
Returns:
|
||||
The instance of state_cls associated with calling_state's client_token.
|
||||
|
||||
Raises:
|
||||
StateMismatchError: If the state instance is not of the expected type.
|
||||
"""
|
||||
# Fetch all missing parent states from redis.
|
||||
parent_state_of_state_cls = await self._populate_parent_states(
|
||||
calling_state, state_cls
|
||||
)
|
||||
|
||||
# Then get the target state and all its substates.
|
||||
state_in_redis = await self.get_state(
|
||||
token=_substate_key(calling_state.router.session.client_token, state_cls),
|
||||
top_level=False,
|
||||
get_substates=True,
|
||||
parent_state=parent_state_of_state_cls,
|
||||
)
|
||||
|
||||
return state_in_redis
|
||||
|
||||
async def _populate_substates(
|
||||
self,
|
||||
token: str,
|
||||
@ -3357,30 +3440,40 @@ class StateManagerRedis(StateManager):
|
||||
"""
|
||||
client_token, _ = _split_substate_key(token)
|
||||
|
||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
||||
fetch_substates = state._get_potentially_dirty_states()
|
||||
if all_substates:
|
||||
# All substates are requested.
|
||||
fetch_substates = state.get_substates()
|
||||
else:
|
||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
||||
fetch_substates = state._potentially_dirty_substates()
|
||||
fetch_substates.update(state.get_substates())
|
||||
|
||||
tasks = {}
|
||||
link_tasks = set()
|
||||
# Retrieve the necessary substates from redis.
|
||||
for substate_cls in fetch_substates:
|
||||
if substate_cls.get_name() in state.substates:
|
||||
continue
|
||||
substate_name = substate_cls.get_name()
|
||||
tasks[substate_name] = asyncio.create_task(
|
||||
self.get_state(
|
||||
token=_substate_key(client_token, substate_cls),
|
||||
top_level=False,
|
||||
get_substates=all_substates,
|
||||
parent_state=state,
|
||||
if substate_cls in state.get_substates():
|
||||
tasks[substate_name] = asyncio.create_task(
|
||||
self.get_state(
|
||||
token=_substate_key(client_token, substate_cls),
|
||||
top_level=False,
|
||||
get_substates=all_substates,
|
||||
parent_state=state,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
try:
|
||||
state._get_root_state().get_substate(substate_name.split("."))
|
||||
except ValueError:
|
||||
# The requested state is missing, so fetch and link it (and its parents).
|
||||
link_tasks.add(
|
||||
asyncio.create_task(self._link_arbitrary_state(state, substate_cls))
|
||||
)
|
||||
|
||||
for substate_name, substate_task in tasks.items():
|
||||
state.substates[substate_name] = await substate_task
|
||||
await asyncio.gather(*link_tasks)
|
||||
|
||||
@override
|
||||
async def get_state(
|
||||
@ -4153,7 +4246,7 @@ def reload_state_module(
|
||||
if subclass.__module__ == module and module is not None:
|
||||
state.class_subclasses.remove(subclass)
|
||||
state._always_dirty_substates.discard(subclass.get_name())
|
||||
state._computed_var_dependencies = defaultdict(set)
|
||||
state._substate_var_dependencies = defaultdict(set)
|
||||
state._potentially_dirty_substates.discard(subclass.get_name())
|
||||
state._var_dependencies = {}
|
||||
state._init_var_dependency_dicts()
|
||||
state.get_class_substate.cache_clear()
|
||||
|
@ -1826,7 +1826,7 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
|
||||
|
||||
# Explicit var dependencies to track
|
||||
_static_deps: set[str] = dataclasses.field(default_factory=set)
|
||||
_static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
||||
|
||||
# Whether var dependencies should be auto-determined
|
||||
_auto_deps: bool = dataclasses.field(default=True)
|
||||
@ -1901,21 +1901,40 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
|
||||
object.__setattr__(self, "_update_interval", interval)
|
||||
|
||||
if deps is None:
|
||||
deps = []
|
||||
else:
|
||||
_static_deps = {}
|
||||
if isinstance(deps, dict):
|
||||
# Assume a dict is coming from _replace, so no special processing.
|
||||
_static_deps = deps
|
||||
elif deps is not None:
|
||||
for dep in deps:
|
||||
if isinstance(dep, Var):
|
||||
continue
|
||||
if isinstance(dep, str) and dep != "":
|
||||
continue
|
||||
raise TypeError(
|
||||
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
||||
)
|
||||
state_name = (
|
||||
all_var_data.state
|
||||
if (all_var_data := dep._get_all_var_data())
|
||||
and all_var_data.state
|
||||
else None
|
||||
)
|
||||
var_name = (
|
||||
dep._js_expr[len(formatted_state_prefix) :]
|
||||
if state_name
|
||||
and (
|
||||
formatted_state_prefix := format_state_name(state_name)
|
||||
+ "."
|
||||
)
|
||||
and dep._js_expr.startswith(formatted_state_prefix)
|
||||
else dep._js_expr
|
||||
)
|
||||
_static_deps.setdefault(state_name, set()).add(var_name)
|
||||
elif isinstance(dep, str) and dep != "":
|
||||
_static_deps.setdefault(None, set()).add(dep)
|
||||
else:
|
||||
raise TypeError(
|
||||
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
||||
)
|
||||
object.__setattr__(
|
||||
self,
|
||||
"_static_deps",
|
||||
{dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
|
||||
_static_deps,
|
||||
)
|
||||
object.__setattr__(self, "_auto_deps", auto_deps)
|
||||
|
||||
@ -2081,6 +2100,11 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||
value = getattr(instance, self._cache_attr)
|
||||
|
||||
self._check_deprecated_return_type(instance, value)
|
||||
|
||||
return value
|
||||
|
||||
def _check_deprecated_return_type(self, instance, value) -> None:
|
||||
if not _isinstance(value, self._var_type):
|
||||
console.deprecate(
|
||||
"mismatched-computed-var-return",
|
||||
@ -2090,41 +2114,49 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
"0.7.0",
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def _deps(
|
||||
self,
|
||||
objclass: Type,
|
||||
objclass: BaseState,
|
||||
obj: FunctionType | CodeType | None = None,
|
||||
self_name: Optional[str] = None,
|
||||
) -> set[str]:
|
||||
self_names: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Determine var dependencies of this ComputedVar.
|
||||
|
||||
Save references to attributes accessed on "self". Recursively called
|
||||
when the function makes a method call on "self" or define comprehensions
|
||||
or nested functions that may reference "self".
|
||||
Save references to attributes accessed on "self" or other fetched states.
|
||||
|
||||
Recursively called when the function makes a method call on "self" or
|
||||
define comprehensions or nested functions that may reference "self".
|
||||
|
||||
Args:
|
||||
objclass: the class obj this ComputedVar is attached to.
|
||||
obj: the object to disassemble (defaults to the fget function).
|
||||
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
|
||||
self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions.
|
||||
|
||||
Returns:
|
||||
A set of variable names accessed by the given obj.
|
||||
A dictionary mapping state names to the set of variable names
|
||||
accessed by the given obj.
|
||||
|
||||
Raises:
|
||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
||||
(cannot track deps in a related state, only implicitly via parent state).
|
||||
"""
|
||||
from reflex.state import BaseState
|
||||
|
||||
d = {}
|
||||
if self._static_deps:
|
||||
d.update(self._static_deps)
|
||||
# None is a placeholder for the current state class.
|
||||
if None in d:
|
||||
d[objclass.get_full_name()] = d.pop(None)
|
||||
|
||||
if not self._auto_deps:
|
||||
return self._static_deps
|
||||
d = self._static_deps.copy()
|
||||
return d
|
||||
if obj is None:
|
||||
fget = self._fget
|
||||
if fget is not None:
|
||||
obj = cast(FunctionType, fget)
|
||||
else:
|
||||
return set()
|
||||
return d
|
||||
with contextlib.suppress(AttributeError):
|
||||
# unbox functools.partial
|
||||
obj = cast(FunctionType, obj.func) # type: ignore
|
||||
@ -2132,76 +2164,150 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
# unbox EventHandler
|
||||
obj = cast(FunctionType, obj.fn) # type: ignore
|
||||
|
||||
if self_name is None and isinstance(obj, FunctionType):
|
||||
if self_names is None and isinstance(obj, FunctionType):
|
||||
try:
|
||||
# the first argument to the function is the name of "self" arg
|
||||
self_name = obj.__code__.co_varnames[0]
|
||||
self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()}
|
||||
except (AttributeError, IndexError):
|
||||
self_name = None
|
||||
if self_name is None:
|
||||
self_names = None
|
||||
if self_names is None:
|
||||
# cannot reference attributes on self if method takes no args
|
||||
return set()
|
||||
return d
|
||||
|
||||
invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
|
||||
self_is_top_of_stack = False
|
||||
invalid_names = ["parent_state", "substates", "get_substate"]
|
||||
self_on_top_of_stack = None
|
||||
getting_state = False
|
||||
getting_var = False
|
||||
for instruction in dis.get_instructions(obj):
|
||||
if getting_state:
|
||||
if instruction.opname == "LOAD_FAST":
|
||||
raise VarValueError(
|
||||
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
||||
)
|
||||
if instruction.opname == "LOAD_GLOBAL":
|
||||
# Special case: referencing state class from global scope.
|
||||
getting_state = obj.__globals__.get(instruction.argval)
|
||||
elif instruction.opname == "LOAD_DEREF":
|
||||
# Special case: referencing state class from closure.
|
||||
closure = dict(zip(obj.__code__.co_freevars, obj.__closure__))
|
||||
try:
|
||||
getting_state = closure[instruction.argval].cell_contents
|
||||
except ValueError as ve:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
||||
) from ve
|
||||
elif instruction.opname == "STORE_FAST":
|
||||
# Storing the result of get_state in a local variable.
|
||||
if not isinstance(getting_state, type) or not issubclass(
|
||||
getting_state, BaseState
|
||||
):
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
||||
)
|
||||
self_names[instruction.argval] = getting_state.get_full_name()
|
||||
getting_state = False
|
||||
continue # nothing else happens until we have identified the local var
|
||||
if getting_var:
|
||||
if instruction.opname == "CALL":
|
||||
# get the original source code and eval it
|
||||
start_line = getting_var[0].positions.lineno
|
||||
start_column = getting_var[0].positions.col_offset
|
||||
end_line = getting_var[-1].positions.end_lineno
|
||||
end_column = getting_var[-1].positions.end_col_offset
|
||||
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line]
|
||||
if len(source) > 1:
|
||||
snipped_source = "".join(
|
||||
[
|
||||
source[0][start_column:],
|
||||
source[1:-2] if len(source) > 2 else "",
|
||||
source[-1][:end_column]
|
||||
]
|
||||
)
|
||||
else:
|
||||
snipped_source = source[0][start_column:end_column]
|
||||
the_var = eval(f"({snipped_source})", obj.__globals__)
|
||||
print(the_var)
|
||||
# code = source[start_line - 1]
|
||||
# bytecode = bytearray((dis.opmap["RESUME"], 0))
|
||||
# for ins in getting_var:
|
||||
# bytecode.append(ins.opcode)
|
||||
# bytecode.append(ins.arg or 0 & 0xFF)
|
||||
# bytecode.extend((dis.opmap["RETURN_VALUE"], 0))
|
||||
# bc = dis.Bytecode(obj)
|
||||
# code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=())
|
||||
# breakpoint()
|
||||
getting_var = False
|
||||
elif isinstance(getting_var, list):
|
||||
getting_var.append(instruction)
|
||||
else:
|
||||
getting_var = [instruction]
|
||||
continue
|
||||
if (
|
||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||
and instruction.argval == self_name
|
||||
and instruction.argval in self_names
|
||||
):
|
||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||
# is referencing an attribute on self
|
||||
self_is_top_of_stack = True
|
||||
self_on_top_of_stack = self_names[instruction.argval]
|
||||
continue
|
||||
if self_is_top_of_stack and instruction.opname in (
|
||||
if self_on_top_of_stack and instruction.opname in (
|
||||
"LOAD_ATTR",
|
||||
"LOAD_METHOD",
|
||||
):
|
||||
try:
|
||||
ref_obj = getattr(objclass, instruction.argval)
|
||||
except Exception:
|
||||
ref_obj = None
|
||||
if instruction.argval in invalid_names:
|
||||
raise VarValueError(
|
||||
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
||||
)
|
||||
if instruction.argval == "get_state":
|
||||
# Special case: arbitrary state access requested.
|
||||
getting_state = True
|
||||
continue
|
||||
if instruction.argval == "get_var_value":
|
||||
# Special case: arbitrary var access requested.
|
||||
getting_var = True
|
||||
continue
|
||||
print(f"{self_on_top_of_stack=}")
|
||||
target_state = objclass.get_root_state().get_class_substate(
|
||||
self_on_top_of_stack
|
||||
)
|
||||
try:
|
||||
ref_obj = getattr(target_state, instruction.argval)
|
||||
except Exception:
|
||||
ref_obj = None
|
||||
if callable(ref_obj):
|
||||
# recurse into callable attributes
|
||||
d.update(
|
||||
self._deps(
|
||||
objclass=objclass,
|
||||
obj=ref_obj,
|
||||
)
|
||||
)
|
||||
for state_name, dep_name in self._deps(
|
||||
objclass=target_state,
|
||||
obj=ref_obj,
|
||||
).items():
|
||||
d.setdefault(state_name, set()).update(dep_name)
|
||||
# recurse into property fget functions
|
||||
elif isinstance(ref_obj, property) and not isinstance(
|
||||
ref_obj, ComputedVar
|
||||
):
|
||||
d.update(
|
||||
self._deps(
|
||||
objclass=objclass,
|
||||
obj=ref_obj.fget, # type: ignore
|
||||
)
|
||||
)
|
||||
for state_name, dep_name in self._deps(
|
||||
objclass=target_state,
|
||||
obj=ref_obj.fget, # type: ignore
|
||||
).items():
|
||||
d.setdefault(state_name, set()).update(dep_name)
|
||||
elif (
|
||||
instruction.argval in objclass.backend_vars
|
||||
or instruction.argval in objclass.vars
|
||||
instruction.argval in target_state.backend_vars
|
||||
or instruction.argval in target_state.vars
|
||||
):
|
||||
# var access
|
||||
d.add(instruction.argval)
|
||||
d.setdefault(self_on_top_of_stack, set()).add(instruction.argval)
|
||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||
instruction.argval, CodeType
|
||||
):
|
||||
# recurse into nested functions / comprehensions, which can reference
|
||||
# instance attributes from the outer scope
|
||||
d.update(
|
||||
self._deps(
|
||||
objclass=objclass,
|
||||
obj=instruction.argval,
|
||||
self_name=self_name,
|
||||
)
|
||||
)
|
||||
self_is_top_of_stack = False
|
||||
for state_name, dep_name in self._deps(
|
||||
objclass=objclass,
|
||||
obj=instruction.argval,
|
||||
self_names=self_names,
|
||||
).items():
|
||||
d.setdefault(state_name, set()).update(dep_name)
|
||||
self_on_top_of_stack = None
|
||||
return d
|
||||
|
||||
def mark_dirty(self, instance) -> None:
|
||||
@ -2249,6 +2355,60 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
init=False,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||
"""A computed var that wraps a coroutinefunction."""
|
||||
|
||||
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
|
||||
default_factory=lambda: lambda _: None
|
||||
) # type: ignore
|
||||
|
||||
def __get__(self, instance: BaseState | None, owner):
|
||||
"""Get the ComputedVar value.
|
||||
|
||||
If the value is already cached on the instance, return the cached value.
|
||||
|
||||
Args:
|
||||
instance: the instance of the class accessing this computed var.
|
||||
owner: the class that this descriptor is attached to.
|
||||
|
||||
Returns:
|
||||
The value of the var for the given instance.
|
||||
"""
|
||||
if instance is None:
|
||||
return super(AsyncComputedVar, self).__get__(instance, owner)
|
||||
|
||||
if not self._cache:
|
||||
|
||||
async def _awaitable_result():
|
||||
value = await self.fget(instance)
|
||||
self._check_deprecated_return_type(instance, value)
|
||||
|
||||
return _awaitable_result()
|
||||
else:
|
||||
# handle caching
|
||||
async def _awaitable_result():
|
||||
if not hasattr(instance, self._cache_attr) or self.needs_update(
|
||||
instance
|
||||
):
|
||||
# Set cache attr on state instance.
|
||||
setattr(instance, self._cache_attr, await self.fget(instance))
|
||||
# Ensure the computed var gets serialized to redis.
|
||||
instance._was_touched = True
|
||||
# Set the last updated timestamp on the state instance.
|
||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||
value = getattr(instance, self._cache_attr)
|
||||
self._check_deprecated_return_type(instance, value)
|
||||
return value
|
||||
|
||||
return _awaitable_result()
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
|
||||
|
||||
@ -2315,10 +2475,27 @@ def computed_var(
|
||||
raise VarDependencyError("Cannot track dependencies without caching.")
|
||||
|
||||
if fget is not None:
|
||||
return ComputedVar(fget, cache=cache)
|
||||
if inspect.iscoroutinefunction(fget):
|
||||
computed_var_cls = AsyncComputedVar
|
||||
else:
|
||||
computed_var_cls = ComputedVar
|
||||
return computed_var_cls(
|
||||
fget,
|
||||
initial_value=initial_value,
|
||||
cache=cache,
|
||||
deps=deps,
|
||||
auto_deps=auto_deps,
|
||||
interval=interval,
|
||||
backend=backend,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
|
||||
return ComputedVar(
|
||||
if inspect.iscoroutinefunction(fget):
|
||||
computed_var_cls = AsyncComputedVar
|
||||
else:
|
||||
computed_var_cls = ComputedVar
|
||||
return computed_var_cls(
|
||||
fget,
|
||||
initial_value=initial_value,
|
||||
cache=cache,
|
||||
|
@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
||||
assert app.pages.keys() == {"test/[dynamic]"}
|
||||
assert "dynamic" in app.state.computed_vars
|
||||
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
||||
constants.ROUTER
|
||||
EmptyState.get_full_name(): {constants.ROUTER},
|
||||
}
|
||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
||||
assert constants.ROUTER in app.state()._var_dependencies
|
||||
|
||||
|
||||
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
|
||||
@ -997,9 +997,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
assert arg_name in app.state.vars
|
||||
assert arg_name in app.state.computed_vars
|
||||
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
|
||||
constants.ROUTER
|
||||
DynamicState.get_full_name(): {constants.ROUTER},
|
||||
}
|
||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
||||
assert constants.ROUTER in app.state()._var_dependencies
|
||||
|
||||
substate_token = _substate_key(token, DynamicState)
|
||||
sid = "mock_sid"
|
||||
@ -1557,6 +1557,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
|
||||
def bar(self) -> str:
|
||||
return "bar"
|
||||
|
||||
class Child1(ValidDepState):
|
||||
@computed_var(deps=["base", ValidDepState.bar])
|
||||
def other(self) -> str:
|
||||
return "other"
|
||||
|
||||
class Child2(ValidDepState):
|
||||
@computed_var(deps=["base", Child1.other])
|
||||
def other(self) -> str:
|
||||
return "other"
|
||||
|
||||
app.state = ValidDepState
|
||||
app._compile()
|
||||
|
||||
|
@ -1170,13 +1170,11 @@ def test_conditional_computed_vars():
|
||||
|
||||
ms = MainState()
|
||||
# Initially there are no dirty computed vars.
|
||||
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
|
||||
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
|
||||
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
|
||||
assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")}
|
||||
assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")}
|
||||
assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")}
|
||||
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
|
||||
"flag",
|
||||
"t1",
|
||||
"t2",
|
||||
MainState.get_full_name(): {"flag", "t1", "t2"}
|
||||
}
|
||||
|
||||
|
||||
@ -1371,7 +1369,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
assert isinstance(HandlerState.handler, EventHandler)
|
||||
|
||||
s = HandlerState()
|
||||
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
|
||||
assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"]
|
||||
assert s.cached_x_side_effect == 1
|
||||
assert s.x == 43
|
||||
s.handler()
|
||||
@ -1461,15 +1459,15 @@ def test_computed_var_dependencies():
|
||||
return [z in self._z for z in range(5)]
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs._computed_var_dependencies["v"] == {
|
||||
"comp_v",
|
||||
"comp_v_backend",
|
||||
"comp_v_via_property",
|
||||
assert cs._var_dependencies["v"] == {
|
||||
(ComputedState.get_full_name(), "comp_v"),
|
||||
(ComputedState.get_full_name(), "comp_v_backend"),
|
||||
(ComputedState.get_full_name(), "comp_v_via_property"),
|
||||
}
|
||||
assert cs._computed_var_dependencies["w"] == {"comp_w"}
|
||||
assert cs._computed_var_dependencies["x"] == {"comp_x"}
|
||||
assert cs._computed_var_dependencies["y"] == {"comp_y"}
|
||||
assert cs._computed_var_dependencies["_z"] == {"comp_z"}
|
||||
assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
|
||||
assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
|
||||
assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
|
||||
assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
|
||||
|
||||
|
||||
def test_backend_method():
|
||||
@ -3182,6 +3180,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
|
||||
RxState = State
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is maybe not relevant anymore.")
|
||||
def test_potentially_dirty_substates():
|
||||
"""Test that potentially_dirty_substates returns the correct substates.
|
||||
|
||||
@ -3203,7 +3202,8 @@ def test_potentially_dirty_substates():
|
||||
assert C1._potentially_dirty_substates() == set()
|
||||
|
||||
|
||||
def test_router_var_dep() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_var_dep() -> None:
|
||||
"""Test that router var dependencies are correctly tracked."""
|
||||
|
||||
class RouterVarParentState(State):
|
||||
@ -3221,13 +3221,9 @@ def test_router_var_dep() -> None:
|
||||
foo = RouterVarDepState.computed_vars["foo"]
|
||||
State._init_var_dependency_dicts()
|
||||
|
||||
assert foo._deps(objclass=RouterVarDepState) == {"router"}
|
||||
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
|
||||
assert RouterVarParentState._substate_var_dependencies == {
|
||||
"router": {RouterVarDepState.get_name()}
|
||||
}
|
||||
assert RouterVarDepState._computed_var_dependencies == {
|
||||
"router": {"foo"},
|
||||
assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}}
|
||||
assert State._var_dependencies == {
|
||||
"router": {(RouterVarDepState.get_full_name(), "foo")}
|
||||
}
|
||||
|
||||
rx_state = State()
|
||||
@ -3240,11 +3236,15 @@ def test_router_var_dep() -> None:
|
||||
state.parent_state = parent_state
|
||||
parent_state.substates = {RouterVarDepState.get_name(): state}
|
||||
|
||||
populated_substate_classes = await rx_state._recursively_populate_dependent_substates()
|
||||
assert populated_substate_classes == {State, RouterVarDepState}
|
||||
|
||||
assert state.dirty_vars == set()
|
||||
|
||||
# Reassign router var
|
||||
state.router = state.router
|
||||
assert state.dirty_vars == {"foo", "router"}
|
||||
assert rx_state.dirty_vars == {"router"}
|
||||
assert state.dirty_vars == {"foo"}
|
||||
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
|
||||
|
||||
|
||||
@ -3803,3 +3803,74 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
|
||||
# Generic Var with no state
|
||||
with pytest.raises(UnretrievableVarValueError):
|
||||
await state.get_var_value(rx.Var("undefined"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
|
||||
"""A test where an async computed var depends on a var in another state.
|
||||
|
||||
Args:
|
||||
mock_app: An app that will be returned by `get_app()`
|
||||
token: A token.
|
||||
"""
|
||||
|
||||
class Parent(BaseState):
|
||||
"""A root state like rx.State."""
|
||||
|
||||
parent_var: int = 0
|
||||
|
||||
class Child2(Parent):
|
||||
"""An unconnected child state."""
|
||||
|
||||
pass
|
||||
|
||||
class Child3(Parent):
|
||||
"""A child state with a computed var causing it to be pre-fetched.
|
||||
|
||||
If child3_var gets set to a value, and `get_state` erroneously
|
||||
re-fetches it from redis, the value will be lost.
|
||||
"""
|
||||
|
||||
child3_var: int = 0
|
||||
|
||||
@rx.var(cache=True)
|
||||
def v(self):
|
||||
return self.child3_var
|
||||
|
||||
class Child(Parent):
|
||||
"""A state simulating UpdateVarsInternalState."""
|
||||
|
||||
@rx.var(cache=True)
|
||||
async def v(self):
|
||||
p = await self.get_state(Parent)
|
||||
child3 = await self.get_state(Child3)
|
||||
return child3.child3_var + p.parent_var
|
||||
|
||||
mock_app.state_manager.state = mock_app.state = Parent
|
||||
|
||||
# Get the top level state via unconnected sibling.
|
||||
root = await mock_app.state_manager.get_state(_substate_key(token, Child))
|
||||
# Set value in parent_var to assert it does not get refetched later.
|
||||
root.parent_var = 1
|
||||
|
||||
if isinstance(mock_app.state_manager, StateManagerRedis):
|
||||
# When redis is used, only states with uncached computed vars are pre-fetched.
|
||||
assert Child2.get_name() not in root.substates
|
||||
assert Child3.get_name() not in root.substates
|
||||
|
||||
# Get the unconnected sibling state, which will be used to `get_state` other instances.
|
||||
child = root.get_substate(Child.get_full_name().split("."))
|
||||
|
||||
# Get an uncached child state.
|
||||
child2 = await child.get_state(Child2)
|
||||
assert child2.parent_var == 1
|
||||
|
||||
# Set value on already-cached Child3 state (prefetched because it has a Computed Var).
|
||||
child3 = await child.get_state(Child3)
|
||||
child3.child3_var = 1
|
||||
|
||||
assert await child.v == 2
|
||||
assert await child.v == 2
|
||||
root.parent_var = 2
|
||||
assert await child.v == 3
|
||||
|
||||
|
@ -15,6 +15,7 @@ from reflex.utils.exceptions import PrimitiveUnserializableToJSON
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import (
|
||||
AsyncComputedVar,
|
||||
ComputedVar,
|
||||
LiteralVar,
|
||||
Var,
|
||||
@ -1808,9 +1809,9 @@ def cv_fget(state: BaseState) -> int:
|
||||
@pytest.mark.parametrize(
|
||||
"deps,expected",
|
||||
[
|
||||
(["a"], {"a"}),
|
||||
(["b"], {"b"}),
|
||||
([ComputedVar(fget=cv_fget)], {"cv_fget"}),
|
||||
(["a"], {None: {"a"}}),
|
||||
(["b"], {None: {"b"}}),
|
||||
([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
|
||||
],
|
||||
)
|
||||
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
|
||||
@ -1856,3 +1857,25 @@ def test_to_string_operation():
|
||||
|
||||
single_var = Var.create(Email())
|
||||
assert single_var._var_type == Email
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_computed_var():
|
||||
side_effect_counter = 0
|
||||
|
||||
class AsyncComputedVarState(BaseState):
|
||||
v: int = 1
|
||||
|
||||
@computed_var(cache=True)
|
||||
async def async_computed_var(self) -> int:
|
||||
nonlocal side_effect_counter
|
||||
side_effect_counter += 1
|
||||
return self.v + 1
|
||||
|
||||
my_state = AsyncComputedVarState()
|
||||
assert await my_state.async_computed_var == 2
|
||||
assert await my_state.async_computed_var == 2
|
||||
my_state.v = 2
|
||||
assert await my_state.async_computed_var == 3
|
||||
assert await my_state.async_computed_var == 3
|
||||
assert side_effect_counter == 2
|
||||
|
Loading…
Reference in New Issue
Block a user