[ENG-4326] Async ComputedVar (#4711)
* WiP * Save the var from get_var_name * flatten StateManagerRedis.get_state algorithm simplify fetching of states and avoid repeatedly fetching the same state * Get all the states in a single redis round-trip * update docstrings in StateManagerRedis * Move computed var dep tracking to separate module * Fix pre-commit issues * ComputedVar.add_dependency: explicitly dependency declaration Allow var dependencies to be added at runtime, for example, when defining a ComponentState that depends on vars that cannot be known statically. Fix more pyright issues. * Fix/ignore more pyright issues from recent merge * handle cleaning out _potentially_dirty_states on reload * ignore accessed attributes missing on state class these might be added dynamically later in which case we recompute the dependency tracking dicts... if not, they'll blow up anyway at runtime. * fix playwright tests, which insist on running an asyncio loop --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
7da5fa0e5c
commit
a2243190ff
@ -908,11 +908,17 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
if not var._cache:
|
if not var._cache:
|
||||||
continue
|
continue
|
||||||
deps = var._deps(objclass=state)
|
deps = var._deps(objclass=state)
|
||||||
for dep in deps:
|
for state_name, dep_set in deps.items():
|
||||||
if dep not in state.vars and dep not in state.backend_vars:
|
state_cls = (
|
||||||
raise exceptions.VarDependencyError(
|
state.get_root_state().get_class_substate(state_name)
|
||||||
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
|
if state_name != state.get_full_name()
|
||||||
)
|
else state
|
||||||
|
)
|
||||||
|
for dep in dep_set:
|
||||||
|
if dep not in state_cls.vars and dep not in state_cls.backend_vars:
|
||||||
|
raise exceptions.VarDependencyError(
|
||||||
|
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
|
||||||
|
)
|
||||||
|
|
||||||
for substate in state.class_subclasses:
|
for substate in state.class_subclasses:
|
||||||
self._validate_var_dependencies(substate)
|
self._validate_var_dependencies(substate)
|
||||||
|
@ -2,12 +2,15 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from reflex.utils.exec import is_in_app_harness
|
||||||
from reflex.utils.prerequisites import get_web_dir
|
from reflex.utils.prerequisites import get_web_dir
|
||||||
from reflex.vars.base import Var
|
from reflex.vars.base import Var
|
||||||
|
|
||||||
@ -33,7 +36,7 @@ from reflex.components.base import (
|
|||||||
)
|
)
|
||||||
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
||||||
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
|
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
|
||||||
from reflex.state import BaseState
|
from reflex.state import BaseState, _resolve_delta
|
||||||
from reflex.style import Style
|
from reflex.style import Style
|
||||||
from reflex.utils import console, format, imports, path_ops
|
from reflex.utils import console, format, imports, path_ops
|
||||||
from reflex.utils.imports import ImportVar, ParsedImportDict
|
from reflex.utils.imports import ImportVar, ParsedImportDict
|
||||||
@ -177,7 +180,24 @@ def compile_state(state: Type[BaseState]) -> dict:
|
|||||||
initial_state = state(_reflex_internal_init=True).dict(
|
initial_state = state(_reflex_internal_init=True).dict(
|
||||||
initial=True, include_computed=False
|
initial=True, include_computed=False
|
||||||
)
|
)
|
||||||
return initial_state
|
try:
|
||||||
|
_ = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if is_in_app_harness():
|
||||||
|
# Playwright tests already have an event loop running, so we can't use asyncio.run.
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
resolved_initial_state = pool.submit(
|
||||||
|
asyncio.run, _resolve_delta(initial_state)
|
||||||
|
).result()
|
||||||
|
console.warn(
|
||||||
|
f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
|
||||||
|
)
|
||||||
|
return resolved_initial_state
|
||||||
|
|
||||||
|
# Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
|
||||||
|
return asyncio.run(_resolve_delta(initial_state))
|
||||||
|
|
||||||
|
|
||||||
def _compile_client_storage_field(
|
def _compile_client_storage_field(
|
||||||
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.event import Event, get_hydrate_event
|
from reflex.event import Event, get_hydrate_event
|
||||||
from reflex.middleware.middleware import Middleware
|
from reflex.middleware.middleware import Middleware
|
||||||
from reflex.state import BaseState, StateUpdate
|
from reflex.state import BaseState, StateUpdate, _resolve_delta
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from reflex.app import App
|
from reflex.app import App
|
||||||
@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
|
|||||||
setattr(state, constants.CompileVars.IS_HYDRATED, False)
|
setattr(state, constants.CompileVars.IS_HYDRATED, False)
|
||||||
|
|
||||||
# Get the initial state.
|
# Get the initial state.
|
||||||
delta = state.dict()
|
delta = await _resolve_delta(state.dict())
|
||||||
# since a full dict was captured, clean any dirtiness
|
# since a full dict was captured, clean any dirtiness
|
||||||
state._clean()
|
state._clean()
|
||||||
|
|
||||||
|
570
reflex/state.py
570
reflex/state.py
@ -15,7 +15,6 @@ import time
|
|||||||
import typing
|
import typing
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import FunctionType, MethodType
|
from types import FunctionType, MethodType
|
||||||
@ -329,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_delta(delta: Delta) -> Delta:
|
||||||
|
"""Await all coroutines in the delta.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delta: The delta to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same delta dict with all coroutines resolved to their return value.
|
||||||
|
"""
|
||||||
|
tasks = {}
|
||||||
|
for state_name, state_delta in delta.items():
|
||||||
|
for var_name, value in state_delta.items():
|
||||||
|
if asyncio.iscoroutine(value):
|
||||||
|
tasks[state_name, var_name] = asyncio.create_task(value)
|
||||||
|
for (state_name, var_name), task in tasks.items():
|
||||||
|
delta[state_name][var_name] = await task
|
||||||
|
return delta
|
||||||
|
|
||||||
|
|
||||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||||
"""The state of the app."""
|
"""The state of the app."""
|
||||||
|
|
||||||
@ -356,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# A set of subclassses of this class.
|
# A set of subclassses of this class.
|
||||||
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
||||||
|
|
||||||
# Mapping of var name to set of computed variables that depend on it
|
# Mapping of var name to set of (state_full_name, var_name) that depend on it.
|
||||||
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
_var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
|
||||||
|
|
||||||
# Mapping of var name to set of substates that depend on it
|
|
||||||
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
||||||
|
|
||||||
# Set of vars which always need to be recomputed
|
# Set of vars which always need to be recomputed
|
||||||
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
||||||
@ -368,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Set of substates which always need to be recomputed
|
# Set of substates which always need to be recomputed
|
||||||
_always_dirty_substates: ClassVar[Set[str]] = set()
|
_always_dirty_substates: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
|
# Set of states which might need to be recomputed if vars in this state change.
|
||||||
|
_potentially_dirty_states: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
# The parent state.
|
# The parent state.
|
||||||
parent_state: Optional[BaseState] = None
|
parent_state: Optional[BaseState] = None
|
||||||
|
|
||||||
@ -519,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
# Reset dirty substate tracking for this class.
|
# Reset dirty substate tracking for this class.
|
||||||
cls._always_dirty_substates = set()
|
cls._always_dirty_substates = set()
|
||||||
|
cls._potentially_dirty_states = set()
|
||||||
|
|
||||||
# Get the parent vars.
|
# Get the parent vars.
|
||||||
parent_state = cls.get_parent_state()
|
parent_state = cls.get_parent_state()
|
||||||
@ -622,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
setattr(cls, name, handler)
|
setattr(cls, name, handler)
|
||||||
|
|
||||||
# Initialize per-class var dependency tracking.
|
# Initialize per-class var dependency tracking.
|
||||||
cls._computed_var_dependencies = defaultdict(set)
|
cls._var_dependencies = {}
|
||||||
cls._substate_var_dependencies = defaultdict(set)
|
|
||||||
cls._init_var_dependency_dicts()
|
cls._init_var_dependency_dicts()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -768,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Additional updates tracking dicts for vars and substates that always
|
Additional updates tracking dicts for vars and substates that always
|
||||||
need to be recomputed.
|
need to be recomputed.
|
||||||
"""
|
"""
|
||||||
inherited_vars = set(cls.inherited_vars).union(
|
|
||||||
set(cls.inherited_backend_vars),
|
|
||||||
)
|
|
||||||
for cvar_name, cvar in cls.computed_vars.items():
|
for cvar_name, cvar in cls.computed_vars.items():
|
||||||
# Add the dependencies.
|
if not cvar._cache:
|
||||||
for var in cvar._deps(objclass=cls):
|
# Do not perform dep calculation when cache=False (these are always dirty).
|
||||||
cls._computed_var_dependencies[var].add(cvar_name)
|
continue
|
||||||
if var in inherited_vars:
|
for state_name, dvar_set in cvar._deps(objclass=cls).items():
|
||||||
# track that this substate depends on its parent for this var
|
state_cls = cls.get_root_state().get_class_substate(state_name)
|
||||||
state_name = cls.get_name()
|
for dvar in dvar_set:
|
||||||
parent_state = cls.get_parent_state()
|
defining_state_cls = state_cls
|
||||||
while parent_state is not None and var in {
|
while dvar in {
|
||||||
**parent_state.vars,
|
*defining_state_cls.inherited_vars,
|
||||||
**parent_state.backend_vars,
|
*defining_state_cls.inherited_backend_vars,
|
||||||
}:
|
}:
|
||||||
parent_state._substate_var_dependencies[var].add(state_name)
|
parent_state = defining_state_cls.get_parent_state()
|
||||||
state_name, parent_state = (
|
if parent_state is not None:
|
||||||
parent_state.get_name(),
|
defining_state_cls = parent_state
|
||||||
parent_state.get_parent_state(),
|
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
||||||
)
|
(cls.get_full_name(), cvar_name)
|
||||||
|
)
|
||||||
|
defining_state_cls._potentially_dirty_states.add(
|
||||||
|
cls.get_full_name()
|
||||||
|
)
|
||||||
|
|
||||||
# ComputedVar with cache=False always need to be recomputed
|
# ComputedVar with cache=False always need to be recomputed
|
||||||
cls._always_dirty_computed_vars = {
|
cls._always_dirty_computed_vars = {
|
||||||
@ -902,6 +921,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
raise ValueError(f"Only one parent state is allowed {parent_states}.")
|
raise ValueError(f"Only one parent state is allowed {parent_states}.")
|
||||||
return parent_states[0] if len(parent_states) == 1 else None
|
return parent_states[0] if len(parent_states) == 1 else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@functools.lru_cache()
|
||||||
|
def get_root_state(cls) -> Type[BaseState]:
|
||||||
|
"""Get the root state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The root state.
|
||||||
|
"""
|
||||||
|
parent_state = cls.get_parent_state()
|
||||||
|
return cls if parent_state is None else parent_state.get_root_state()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_substates(cls) -> set[Type[BaseState]]:
|
def get_substates(cls) -> set[Type[BaseState]]:
|
||||||
"""Get the substates of the state.
|
"""Get the substates of the state.
|
||||||
@ -1351,7 +1381,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
# Add the var to the dirty list.
|
# Add the var to the dirty list.
|
||||||
if name in self.vars or name in self._computed_var_dependencies:
|
if name in self.base_vars:
|
||||||
self.dirty_vars.add(name)
|
self.dirty_vars.add(name)
|
||||||
self._mark_dirty()
|
self._mark_dirty()
|
||||||
|
|
||||||
@ -1422,64 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
return self.substates[path[0]].get_substate(path[1:])
|
return self.substates[path[0]].get_substate(path[1:])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
|
||||||
"""Find the name of the nearest common ancestor shared by this and the other state.
|
"""Get substates which may have dirty vars due to dependencies.
|
||||||
|
|
||||||
Args:
|
|
||||||
other: The other state.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Full name of the nearest common ancestor.
|
The set of potentially dirty substate classes.
|
||||||
"""
|
"""
|
||||||
common_ancestor_parts = []
|
return {
|
||||||
for part1, part2 in zip(
|
cls.get_class_substate(substate_name)
|
||||||
cls.get_full_name().split("."),
|
for substate_name in cls._always_dirty_substates
|
||||||
other.get_full_name().split("."),
|
}.union(
|
||||||
strict=True,
|
{
|
||||||
):
|
cls.get_root_state().get_class_substate(substate_name)
|
||||||
if part1 != part2:
|
for substate_name in cls._potentially_dirty_states
|
||||||
break
|
}
|
||||||
common_ancestor_parts.append(part1)
|
)
|
||||||
return ".".join(common_ancestor_parts)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _determine_missing_parent_states(
|
|
||||||
cls, target_state_cls: Type[BaseState]
|
|
||||||
) -> tuple[str, list[str]]:
|
|
||||||
"""Determine the missing parent states between the target_state_cls and common ancestor of this state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_state_cls: The class of the state to find missing parent states for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The name of the common ancestor and the list of missing parent states.
|
|
||||||
"""
|
|
||||||
common_ancestor_name = cls._get_common_ancestor(target_state_cls)
|
|
||||||
common_ancestor_parts = common_ancestor_name.split(".")
|
|
||||||
target_state_parts = tuple(target_state_cls.get_full_name().split("."))
|
|
||||||
relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
|
|
||||||
|
|
||||||
# Determine which parent states to fetch from the common ancestor down to the target_state_cls.
|
|
||||||
fetch_parent_states = [common_ancestor_name]
|
|
||||||
for relative_parent_state_name in relative_target_state_parts:
|
|
||||||
fetch_parent_states.append(
|
|
||||||
".".join((fetch_parent_states[-1], relative_parent_state_name))
|
|
||||||
)
|
|
||||||
|
|
||||||
return common_ancestor_name, fetch_parent_states[1:-1]
|
|
||||||
|
|
||||||
def _get_parent_states(self) -> list[tuple[str, BaseState]]:
|
|
||||||
"""Get all parent state instances up to the root of the state tree.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of tuples containing the name and the instance of each parent state.
|
|
||||||
"""
|
|
||||||
parent_states_with_name = []
|
|
||||||
parent_state = self
|
|
||||||
while parent_state.parent_state is not None:
|
|
||||||
parent_state = parent_state.parent_state
|
|
||||||
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
|
||||||
return parent_states_with_name
|
|
||||||
|
|
||||||
def _get_root_state(self) -> BaseState:
|
def _get_root_state(self) -> BaseState:
|
||||||
"""Get the root state of the state tree.
|
"""Get the root state of the state tree.
|
||||||
@ -1492,55 +1479,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
parent_state = parent_state.parent_state
|
parent_state = parent_state.parent_state
|
||||||
return parent_state
|
return parent_state
|
||||||
|
|
||||||
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||||
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
"""Get a state instance from redis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_state_cls: The class of the state to populate parent states for.
|
state_cls: The class of the state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The parent state instance of target_state_cls.
|
The instance of state_cls associated with this state's client_token.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If redis is not used in this backend process.
|
RuntimeError: If redis is not used in this backend process.
|
||||||
|
StateMismatchError: If the state instance is not of the expected type.
|
||||||
"""
|
"""
|
||||||
|
# Then get the target state and all its substates.
|
||||||
state_manager = get_state_manager()
|
state_manager = get_state_manager()
|
||||||
if not isinstance(state_manager, StateManagerRedis):
|
if not isinstance(state_manager, StateManagerRedis):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
|
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||||
"(All states should already be available -- this is likely a bug).",
|
"(All states should already be available -- this is likely a bug).",
|
||||||
)
|
)
|
||||||
|
state_in_redis = await state_manager.get_state(
|
||||||
|
token=_substate_key(self.router.session.client_token, state_cls),
|
||||||
|
top_level=False,
|
||||||
|
for_state_instance=self,
|
||||||
|
)
|
||||||
|
|
||||||
# Find the missing parent states up to the common ancestor.
|
if not isinstance(state_in_redis, state_cls):
|
||||||
(
|
raise StateMismatchError(
|
||||||
common_ancestor_name,
|
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
||||||
missing_parent_states,
|
|
||||||
) = self._determine_missing_parent_states(target_state_cls)
|
|
||||||
|
|
||||||
# Fetch all missing parent states and link them up to the common ancestor.
|
|
||||||
parent_states_tuple = self._get_parent_states()
|
|
||||||
root_state = parent_states_tuple[-1][1]
|
|
||||||
parent_states_by_name = dict(parent_states_tuple)
|
|
||||||
parent_state = parent_states_by_name[common_ancestor_name]
|
|
||||||
for parent_state_name in missing_parent_states:
|
|
||||||
try:
|
|
||||||
parent_state = root_state.get_substate(parent_state_name.split("."))
|
|
||||||
# The requested state is already cached, do NOT fetch it again.
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
# The requested state is missing, fetch from redis.
|
|
||||||
pass
|
|
||||||
parent_state = await state_manager.get_state(
|
|
||||||
token=_substate_key(
|
|
||||||
self.router.session.client_token, parent_state_name
|
|
||||||
),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=False,
|
|
||||||
parent_state=parent_state,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return the direct parent of target_state_cls for subsequent linking.
|
return state_in_redis
|
||||||
return parent_state
|
|
||||||
|
|
||||||
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||||
"""Get a state instance from the cache.
|
"""Get a state instance from the cache.
|
||||||
@ -1562,44 +1532,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
)
|
)
|
||||||
return substate
|
return substate
|
||||||
|
|
||||||
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
||||||
"""Get a state instance from redis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_cls: The class of the state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The instance of state_cls associated with this state's client_token.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If redis is not used in this backend process.
|
|
||||||
StateMismatchError: If the state instance is not of the expected type.
|
|
||||||
"""
|
|
||||||
# Fetch all missing parent states from redis.
|
|
||||||
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
||||||
|
|
||||||
# Then get the target state and all its substates.
|
|
||||||
state_manager = get_state_manager()
|
|
||||||
if not isinstance(state_manager, StateManagerRedis):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
||||||
"(All states should already be available -- this is likely a bug).",
|
|
||||||
)
|
|
||||||
|
|
||||||
state_in_redis = await state_manager.get_state(
|
|
||||||
token=_substate_key(self.router.session.client_token, state_cls),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=True,
|
|
||||||
parent_state=parent_state_of_state_cls,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(state_in_redis, state_cls):
|
|
||||||
raise StateMismatchError(
|
|
||||||
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
|
||||||
)
|
|
||||||
|
|
||||||
return state_in_redis
|
|
||||||
|
|
||||||
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||||
"""Get an instance of the state associated with this token.
|
"""Get an instance of the state associated with this token.
|
||||||
|
|
||||||
@ -1738,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _as_state_update(
|
async def _as_state_update(
|
||||||
self,
|
self,
|
||||||
handler: EventHandler,
|
handler: EventHandler,
|
||||||
events: EventSpec | list[EventSpec] | None,
|
events: EventSpec | list[EventSpec] | None,
|
||||||
@ -1766,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the delta after processing the event.
|
# Get the delta after processing the event.
|
||||||
delta = state.get_delta()
|
delta = await _resolve_delta(state.get_delta())
|
||||||
state._clean()
|
state._clean()
|
||||||
|
|
||||||
return StateUpdate(
|
return StateUpdate(
|
||||||
@ -1866,24 +1798,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Handle async generators.
|
# Handle async generators.
|
||||||
if inspect.isasyncgen(events):
|
if inspect.isasyncgen(events):
|
||||||
async for event in events:
|
async for event in events:
|
||||||
yield state._as_state_update(handler, event, final=False)
|
yield await state._as_state_update(handler, event, final=False)
|
||||||
yield state._as_state_update(handler, events=None, final=True)
|
yield await state._as_state_update(handler, events=None, final=True)
|
||||||
|
|
||||||
# Handle regular generators.
|
# Handle regular generators.
|
||||||
elif inspect.isgenerator(events):
|
elif inspect.isgenerator(events):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield state._as_state_update(handler, next(events), final=False)
|
yield await state._as_state_update(
|
||||||
|
handler, next(events), final=False
|
||||||
|
)
|
||||||
except StopIteration as si:
|
except StopIteration as si:
|
||||||
# the "return" value of the generator is not available
|
# the "return" value of the generator is not available
|
||||||
# in the loop, we must catch StopIteration to access it
|
# in the loop, we must catch StopIteration to access it
|
||||||
if si.value is not None:
|
if si.value is not None:
|
||||||
yield state._as_state_update(handler, si.value, final=False)
|
yield await state._as_state_update(
|
||||||
yield state._as_state_update(handler, events=None, final=True)
|
handler, si.value, final=False
|
||||||
|
)
|
||||||
|
yield await state._as_state_update(handler, events=None, final=True)
|
||||||
|
|
||||||
# Handle regular event chains.
|
# Handle regular event chains.
|
||||||
else:
|
else:
|
||||||
yield state._as_state_update(handler, events, final=True)
|
yield await state._as_state_update(handler, events, final=True)
|
||||||
|
|
||||||
# If an error occurs, throw a window alert.
|
# If an error occurs, throw a window alert.
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -1893,7 +1829,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
|
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield state._as_state_update(
|
yield await state._as_state_update(
|
||||||
handler,
|
handler,
|
||||||
event_specs,
|
event_specs,
|
||||||
final=True,
|
final=True,
|
||||||
@ -1901,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
def _mark_dirty_computed_vars(self) -> None:
|
def _mark_dirty_computed_vars(self) -> None:
|
||||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||||
|
# Append expired computed vars to dirty_vars to trigger recalculation
|
||||||
|
self.dirty_vars.update(self._expired_computed_vars())
|
||||||
|
# Append always dirty computed vars to dirty_vars to trigger recalculation
|
||||||
|
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||||
|
|
||||||
dirty_vars = self.dirty_vars
|
dirty_vars = self.dirty_vars
|
||||||
while dirty_vars:
|
while dirty_vars:
|
||||||
calc_vars, dirty_vars = dirty_vars, set()
|
calc_vars, dirty_vars = dirty_vars, set()
|
||||||
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||||
self.dirty_vars.add(cvar)
|
if state_name == self.get_full_name():
|
||||||
|
defining_state = self
|
||||||
|
else:
|
||||||
|
defining_state = self._get_root_state().get_substate(
|
||||||
|
tuple(state_name.split("."))
|
||||||
|
)
|
||||||
|
defining_state.dirty_vars.add(cvar)
|
||||||
dirty_vars.add(cvar)
|
dirty_vars.add(cvar)
|
||||||
actual_var = self.computed_vars.get(cvar)
|
actual_var = defining_state.computed_vars.get(cvar)
|
||||||
if actual_var is not None:
|
if actual_var is not None:
|
||||||
actual_var.mark_dirty(instance=self)
|
actual_var.mark_dirty(instance=defining_state)
|
||||||
|
if defining_state is not self:
|
||||||
|
defining_state._mark_dirty()
|
||||||
|
|
||||||
def _expired_computed_vars(self) -> set[str]:
|
def _expired_computed_vars(self) -> set[str]:
|
||||||
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
||||||
@ -1925,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
def _dirty_computed_vars(
|
def _dirty_computed_vars(
|
||||||
self, from_vars: set[str] | None = None, include_backend: bool = True
|
self, from_vars: set[str] | None = None, include_backend: bool = True
|
||||||
) -> set[str]:
|
) -> set[tuple[str, str]]:
|
||||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1936,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Set of computed vars to include in the delta.
|
Set of computed vars to include in the delta.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
cvar
|
(state_name, cvar)
|
||||||
for dirty_var in from_vars or self.dirty_vars
|
for dirty_var in from_vars or self.dirty_vars
|
||||||
for cvar in self._computed_var_dependencies[dirty_var]
|
for state_name, cvar in self._var_dependencies.get(dirty_var, set())
|
||||||
if include_backend or not self.computed_vars[cvar]._backend
|
if include_backend or not self.computed_vars[cvar]._backend
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
|
||||||
"""Determine substates which could be affected by dirty vars in this state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Set of State classes that may need to be fetched to recalc computed vars.
|
|
||||||
"""
|
|
||||||
# _always_dirty_substates need to be fetched to recalc computed vars.
|
|
||||||
fetch_substates = {
|
|
||||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
||||||
for substate_name in cls._always_dirty_substates
|
|
||||||
}
|
|
||||||
for dependent_substates in cls._substate_var_dependencies.values():
|
|
||||||
fetch_substates.update(
|
|
||||||
{
|
|
||||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
||||||
for substate_name in dependent_substates
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return fetch_substates
|
|
||||||
|
|
||||||
def get_delta(self) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
|
|
||||||
@ -1971,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"""
|
"""
|
||||||
delta = {}
|
delta = {}
|
||||||
|
|
||||||
# Apply dirty variables down into substates
|
self._mark_dirty_computed_vars()
|
||||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
||||||
self._mark_dirty()
|
|
||||||
|
|
||||||
frontend_computed_vars: set[str] = {
|
frontend_computed_vars: set[str] = {
|
||||||
name for name, cv in self.computed_vars.items() if not cv._backend
|
name for name, cv in self.computed_vars.items() if not cv._backend
|
||||||
}
|
}
|
||||||
|
|
||||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||||
# and always dirty computed vars (cache=False)
|
# and always dirty computed vars (cache=False)
|
||||||
delta_vars = (
|
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||||
self.dirty_vars.intersection(self.base_vars)
|
self.dirty_vars.intersection(frontend_computed_vars)
|
||||||
.union(self.dirty_vars.intersection(frontend_computed_vars))
|
|
||||||
.union(self._dirty_computed_vars(include_backend=False))
|
|
||||||
.union(self._always_dirty_computed_vars)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
subdelta: Dict[str, Any] = {
|
subdelta: Dict[str, Any] = {
|
||||||
@ -2015,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
self.parent_state.dirty_substates.add(self.get_name())
|
self.parent_state.dirty_substates.add(self.get_name())
|
||||||
self.parent_state._mark_dirty()
|
self.parent_state._mark_dirty()
|
||||||
|
|
||||||
# Append expired computed vars to dirty_vars to trigger recalculation
|
|
||||||
self.dirty_vars.update(self._expired_computed_vars())
|
|
||||||
|
|
||||||
# have to mark computed vars dirty to allow access to newly computed
|
# have to mark computed vars dirty to allow access to newly computed
|
||||||
# values within the same ComputedVar function
|
# values within the same ComputedVar function
|
||||||
self._mark_dirty_computed_vars()
|
self._mark_dirty_computed_vars()
|
||||||
self._mark_dirty_substates()
|
|
||||||
|
|
||||||
def _mark_dirty_substates(self):
|
|
||||||
"""Propagate dirty var / computed var status into substates."""
|
|
||||||
substates = self.substates
|
|
||||||
for var in self.dirty_vars:
|
|
||||||
for substate_name in self._substate_var_dependencies[var]:
|
|
||||||
self.dirty_substates.add(substate_name)
|
|
||||||
substate = substates[substate_name]
|
|
||||||
substate.dirty_vars.add(var)
|
|
||||||
substate._mark_dirty()
|
|
||||||
|
|
||||||
def _update_was_touched(self):
|
def _update_was_touched(self):
|
||||||
"""Update the _was_touched flag based on dirty_vars."""
|
"""Update the _was_touched flag based on dirty_vars."""
|
||||||
@ -2103,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
The object as a dictionary.
|
The object as a dictionary.
|
||||||
"""
|
"""
|
||||||
if include_computed:
|
if include_computed:
|
||||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
self._mark_dirty_computed_vars()
|
||||||
# trigger recalculation of dependent vars
|
|
||||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
||||||
self._mark_dirty()
|
|
||||||
|
|
||||||
base_vars = {
|
base_vars = {
|
||||||
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
|
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
|
||||||
}
|
}
|
||||||
@ -2824,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
||||||
)
|
)
|
||||||
|
|
||||||
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
||||||
"""Temporarily allow mutability to access parent_state.
|
"""Temporarily allow mutability to access parent_state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -2837,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
original_mutable = self._self_mutable
|
original_mutable = self._self_mutable
|
||||||
self._self_mutable = True
|
self._self_mutable = True
|
||||||
try:
|
try:
|
||||||
return self.__wrapped__._as_state_update(*args, **kwargs)
|
return await self.__wrapped__._as_state_update(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
self._self_mutable = original_mutable
|
self._self_mutable = original_mutable
|
||||||
|
|
||||||
@ -3313,103 +3217,106 @@ class StateManagerRedis(StateManager):
|
|||||||
b"evicted",
|
b"evicted",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _get_parent_state(
|
def _get_required_state_classes(
|
||||||
self, token: str, state: BaseState | None = None
|
self,
|
||||||
) -> BaseState | None:
|
target_state_cls: Type[BaseState],
|
||||||
"""Get the parent state for the state requested in the token.
|
subclasses: bool = False,
|
||||||
|
required_state_classes: set[Type[BaseState]] | None = None,
|
||||||
|
) -> set[Type[BaseState]]:
|
||||||
|
"""Recursively determine which states are required to fetch the target state.
|
||||||
|
|
||||||
|
This will always include potentially dirty substates that depend on vars
|
||||||
|
in the target_state_cls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The token to get the state for (_substate_key).
|
target_state_cls: The target state class being fetched.
|
||||||
state: The state instance to get parent state for.
|
subclasses: Whether to include subclasses of the target state.
|
||||||
|
required_state_classes: Recursive argument tracking state classes that have already been seen.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The parent state for the state requested by the token or None if there is no such parent.
|
The set of state classes required to fetch the target state.
|
||||||
"""
|
"""
|
||||||
parent_state = None
|
if required_state_classes is None:
|
||||||
client_token, state_path = _split_substate_key(token)
|
required_state_classes = set()
|
||||||
parent_state_name = state_path.rpartition(".")[0]
|
# Get the substates if requested.
|
||||||
if parent_state_name:
|
if subclasses:
|
||||||
cached_substates = None
|
for substate in target_state_cls.get_substates():
|
||||||
if state is not None:
|
self._get_required_state_classes(
|
||||||
cached_substates = [state]
|
substate,
|
||||||
# Retrieve the parent state to populate event handlers onto this substate.
|
subclasses=True,
|
||||||
parent_state = await self.get_state(
|
required_state_classes=required_state_classes,
|
||||||
token=_substate_key(client_token, parent_state_name),
|
)
|
||||||
top_level=False,
|
if target_state_cls in required_state_classes:
|
||||||
get_substates=False,
|
return required_state_classes
|
||||||
cached_substates=cached_substates,
|
required_state_classes.add(target_state_cls)
|
||||||
|
|
||||||
|
# Get dependent substates.
|
||||||
|
for pd_substates in target_state_cls._get_potentially_dirty_states():
|
||||||
|
self._get_required_state_classes(
|
||||||
|
pd_substates,
|
||||||
|
subclasses=False,
|
||||||
|
required_state_classes=required_state_classes,
|
||||||
)
|
)
|
||||||
return parent_state
|
|
||||||
|
|
||||||
async def _populate_substates(
|
# Get the parent state if it exists.
|
||||||
|
if parent_state := target_state_cls.get_parent_state():
|
||||||
|
self._get_required_state_classes(
|
||||||
|
parent_state,
|
||||||
|
subclasses=False,
|
||||||
|
required_state_classes=required_state_classes,
|
||||||
|
)
|
||||||
|
return required_state_classes
|
||||||
|
|
||||||
|
def _get_populated_states(
|
||||||
self,
|
self,
|
||||||
token: str,
|
target_state: BaseState,
|
||||||
state: BaseState,
|
populated_states: dict[str, BaseState] | None = None,
|
||||||
all_substates: bool = False,
|
) -> dict[str, BaseState]:
|
||||||
):
|
"""Recursively determine which states from target_state are already fetched.
|
||||||
"""Fetch and link substates for the given state instance.
|
|
||||||
|
|
||||||
There is no return value; the side-effect is that `state` will have `substates` populated,
|
|
||||||
and each substate will have its `parent_state` set to `state`.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The token to get the state for.
|
target_state: The state to check for populated states.
|
||||||
state: The state instance to populate substates for.
|
populated_states: Recursive argument tracking states seen in previous calls.
|
||||||
all_substates: Whether to fetch all substates or just required substates.
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of state full name to state instance.
|
||||||
"""
|
"""
|
||||||
client_token, _ = _split_substate_key(token)
|
if populated_states is None:
|
||||||
|
populated_states = {}
|
||||||
if all_substates:
|
if target_state.get_full_name() in populated_states:
|
||||||
# All substates are requested.
|
return populated_states
|
||||||
fetch_substates = state.get_substates()
|
populated_states[target_state.get_full_name()] = target_state
|
||||||
else:
|
for substate in target_state.substates.values():
|
||||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
self._get_populated_states(substate, populated_states=populated_states)
|
||||||
fetch_substates = state._potentially_dirty_substates()
|
if target_state.parent_state is not None:
|
||||||
|
self._get_populated_states(
|
||||||
tasks = {}
|
target_state.parent_state, populated_states=populated_states
|
||||||
# Retrieve the necessary substates from redis.
|
|
||||||
for substate_cls in fetch_substates:
|
|
||||||
if substate_cls.get_name() in state.substates:
|
|
||||||
continue
|
|
||||||
substate_name = substate_cls.get_name()
|
|
||||||
tasks[substate_name] = asyncio.create_task(
|
|
||||||
self.get_state(
|
|
||||||
token=_substate_key(client_token, substate_cls),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=all_substates,
|
|
||||||
parent_state=state,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
return populated_states
|
||||||
for substate_name, substate_task in tasks.items():
|
|
||||||
state.substates[substate_name] = await substate_task
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_state(
|
async def get_state(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
top_level: bool = True,
|
top_level: bool = True,
|
||||||
get_substates: bool = True,
|
for_state_instance: BaseState | None = None,
|
||||||
parent_state: BaseState | None = None,
|
|
||||||
cached_substates: list[BaseState] | None = None,
|
|
||||||
) -> BaseState:
|
) -> BaseState:
|
||||||
"""Get the state for a token.
|
"""Get the state for a token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The token to get the state for.
|
token: The token to get the state for.
|
||||||
top_level: If true, return an instance of the top-level state (self.state).
|
top_level: If true, return an instance of the top-level state (self.state).
|
||||||
get_substates: If true, also retrieve substates.
|
for_state_instance: If provided, attach the requested states to this existing state tree.
|
||||||
parent_state: If provided, use this parent_state instead of getting it from redis.
|
|
||||||
cached_substates: If provided, attach these substates to the state.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: when the state_cls is not specified in the token
|
RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
|
||||||
|
requested state was not fetched.
|
||||||
"""
|
"""
|
||||||
# Split the actual token from the fully qualified substate name.
|
# Split the actual token from the fully qualified substate name.
|
||||||
_, state_path = _split_substate_key(token)
|
token, state_path = _split_substate_key(token)
|
||||||
if state_path:
|
if state_path:
|
||||||
# Get the State class associated with the given path.
|
# Get the State class associated with the given path.
|
||||||
state_cls = self.state.get_class_substate(state_path)
|
state_cls = self.state.get_class_substate(state_path)
|
||||||
@ -3418,43 +3325,59 @@ class StateManagerRedis(StateManager):
|
|||||||
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# The deserialized or newly created (sub)state instance.
|
# Determine which states we already have.
|
||||||
state = None
|
flat_state_tree: dict[str, BaseState] = (
|
||||||
|
self._get_populated_states(for_state_instance) if for_state_instance else {}
|
||||||
|
)
|
||||||
|
|
||||||
# Fetch the serialized substate from redis.
|
# Determine which states from the tree need to be fetched.
|
||||||
redis_state = await self.redis.get(token)
|
required_state_classes = sorted(
|
||||||
|
self._get_required_state_classes(state_cls, subclasses=True)
|
||||||
|
- {type(s) for s in flat_state_tree.values()},
|
||||||
|
key=lambda x: x.get_full_name(),
|
||||||
|
)
|
||||||
|
|
||||||
if redis_state is not None:
|
redis_pipeline = self.redis.pipeline()
|
||||||
# Deserialize the substate.
|
for state_cls in required_state_classes:
|
||||||
with contextlib.suppress(StateSchemaMismatchError):
|
redis_pipeline.get(_substate_key(token, state_cls))
|
||||||
state = BaseState._deserialize(data=redis_state)
|
|
||||||
if state is None:
|
for state_cls, redis_state in zip(
|
||||||
# Key didn't exist or schema mismatch so create a new instance for this token.
|
required_state_classes,
|
||||||
state = state_cls(
|
await redis_pipeline.execute(),
|
||||||
init_substates=False,
|
strict=False,
|
||||||
_reflex_internal_init=True,
|
):
|
||||||
)
|
state = None
|
||||||
# Populate parent state if missing and requested.
|
|
||||||
if parent_state is None:
|
if redis_state is not None:
|
||||||
parent_state = await self._get_parent_state(token, state)
|
# Deserialize the substate.
|
||||||
# Set up Bidirectional linkage between this state and its parent.
|
with contextlib.suppress(StateSchemaMismatchError):
|
||||||
if parent_state is not None:
|
state = BaseState._deserialize(data=redis_state)
|
||||||
parent_state.substates[state.get_name()] = state
|
if state is None:
|
||||||
state.parent_state = parent_state
|
# Key didn't exist or schema mismatch so create a new instance for this token.
|
||||||
# Avoid fetching substates multiple times.
|
state = state_cls(
|
||||||
if cached_substates:
|
init_substates=False,
|
||||||
for substate in cached_substates:
|
_reflex_internal_init=True,
|
||||||
state.substates[substate.get_name()] = substate
|
)
|
||||||
if substate.parent_state is None:
|
flat_state_tree[state.get_full_name()] = state
|
||||||
substate.parent_state = state
|
if state.get_parent_state() is not None:
|
||||||
# Populate substates if requested.
|
parent_state_name, _dot, state_name = state.get_full_name().rpartition(
|
||||||
await self._populate_substates(token, state, all_substates=get_substates)
|
"."
|
||||||
|
)
|
||||||
|
parent_state = flat_state_tree.get(parent_state_name)
|
||||||
|
if parent_state is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Parent state for {state.get_full_name()} was not found "
|
||||||
|
"in the state tree, but should have already been fetched. "
|
||||||
|
"This is a bug",
|
||||||
|
)
|
||||||
|
parent_state.substates[state_name] = state
|
||||||
|
state.parent_state = parent_state
|
||||||
|
|
||||||
# To retain compatibility with previous implementation, by default, we return
|
# To retain compatibility with previous implementation, by default, we return
|
||||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
# the top-level state which should always be fetched or already cached.
|
||||||
if top_level:
|
if top_level:
|
||||||
return state._get_root_state()
|
return flat_state_tree[self.state.get_full_name()]
|
||||||
return state
|
return flat_state_tree[state_cls.get_full_name()]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def set_state(
|
async def set_state(
|
||||||
@ -4154,12 +4077,19 @@ def reload_state_module(
|
|||||||
state: Recursive argument for the state class to reload.
|
state: Recursive argument for the state class to reload.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# Clean out all potentially dirty states of reloaded modules.
|
||||||
|
for pd_state in tuple(state._potentially_dirty_states):
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
if (
|
||||||
|
state.get_root_state().get_class_substate(pd_state).__module__ == module
|
||||||
|
and module is not None
|
||||||
|
):
|
||||||
|
state._potentially_dirty_states.remove(pd_state)
|
||||||
for subclass in tuple(state.class_subclasses):
|
for subclass in tuple(state.class_subclasses):
|
||||||
reload_state_module(module=module, state=subclass)
|
reload_state_module(module=module, state=subclass)
|
||||||
if subclass.__module__ == module and module is not None:
|
if subclass.__module__ == module and module is not None:
|
||||||
state.class_subclasses.remove(subclass)
|
state.class_subclasses.remove(subclass)
|
||||||
state._always_dirty_substates.discard(subclass.get_name())
|
state._always_dirty_substates.discard(subclass.get_name())
|
||||||
state._computed_var_dependencies = defaultdict(set)
|
state._var_dependencies = {}
|
||||||
state._substate_var_dependencies = defaultdict(set)
|
|
||||||
state._init_var_dependency_dicts()
|
state._init_var_dependency_dicts()
|
||||||
state.get_class_substate.cache_clear()
|
state.get_class_substate.cache_clear()
|
||||||
|
@ -488,7 +488,7 @@ def output_system_info():
|
|||||||
dependencies.append(fnm_info)
|
dependencies.append(fnm_info)
|
||||||
|
|
||||||
if system == "Linux":
|
if system == "Linux":
|
||||||
import distro
|
import distro # pyright: ignore[reportMissingImports]
|
||||||
|
|
||||||
os_version = distro.name(pretty=True)
|
os_version = distro.name(pretty=True)
|
||||||
else:
|
else:
|
||||||
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import datetime
|
import datetime
|
||||||
import dis
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@ -20,6 +19,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
Generic,
|
Generic,
|
||||||
@ -51,7 +51,6 @@ from reflex.utils.exceptions import (
|
|||||||
VarAttributeError,
|
VarAttributeError,
|
||||||
VarDependencyError,
|
VarDependencyError,
|
||||||
VarTypeError,
|
VarTypeError,
|
||||||
VarValueError,
|
|
||||||
)
|
)
|
||||||
from reflex.utils.format import format_state_name
|
from reflex.utils.format import format_state_name
|
||||||
from reflex.utils.imports import (
|
from reflex.utils.imports import (
|
||||||
@ -1983,7 +1982,7 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
|
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
|
||||||
|
|
||||||
# Explicit var dependencies to track
|
# Explicit var dependencies to track
|
||||||
_static_deps: set[str] = dataclasses.field(default_factory=set)
|
_static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
# Whether var dependencies should be auto-determined
|
# Whether var dependencies should be auto-determined
|
||||||
_auto_deps: bool = dataclasses.field(default=True)
|
_auto_deps: bool = dataclasses.field(default=True)
|
||||||
@ -2053,21 +2052,34 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
|
|
||||||
object.__setattr__(self, "_update_interval", interval)
|
object.__setattr__(self, "_update_interval", interval)
|
||||||
|
|
||||||
if deps is None:
|
_static_deps = {}
|
||||||
deps = []
|
if isinstance(deps, dict):
|
||||||
else:
|
# Assume a dict is coming from _replace, so no special processing.
|
||||||
|
_static_deps = deps
|
||||||
|
elif deps is not None:
|
||||||
for dep in deps:
|
for dep in deps:
|
||||||
if isinstance(dep, Var):
|
if isinstance(dep, Var):
|
||||||
continue
|
state_name = (
|
||||||
if isinstance(dep, str) and dep != "":
|
all_var_data.state
|
||||||
continue
|
if (all_var_data := dep._get_all_var_data())
|
||||||
raise TypeError(
|
and all_var_data.state
|
||||||
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
else None
|
||||||
)
|
)
|
||||||
|
if all_var_data is not None:
|
||||||
|
var_name = all_var_data.field_name
|
||||||
|
else:
|
||||||
|
var_name = dep._js_expr
|
||||||
|
_static_deps.setdefault(state_name, set()).add(var_name)
|
||||||
|
elif isinstance(dep, str) and dep != "":
|
||||||
|
_static_deps.setdefault(None, set()).add(dep)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
||||||
|
)
|
||||||
object.__setattr__(
|
object.__setattr__(
|
||||||
self,
|
self,
|
||||||
"_static_deps",
|
"_static_deps",
|
||||||
{dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
|
_static_deps,
|
||||||
)
|
)
|
||||||
object.__setattr__(self, "_auto_deps", auto_deps)
|
object.__setattr__(self, "_auto_deps", auto_deps)
|
||||||
|
|
||||||
@ -2149,6 +2161,13 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
return True
|
return True
|
||||||
return datetime.datetime.now() - last_updated > self._update_interval
|
return datetime.datetime.now() - last_updated > self._update_interval
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: ComputedVar[bool],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> BooleanVar: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(
|
def __get__(
|
||||||
self: ComputedVar[int] | ComputedVar[float],
|
self: ComputedVar[int] | ComputedVar[float],
|
||||||
@ -2233,125 +2252,67 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
value = getattr(instance, self._cache_attr)
|
value = getattr(instance, self._cache_attr)
|
||||||
|
|
||||||
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
|
||||||
if not _isinstance(value, self._var_type):
|
if not _isinstance(value, self._var_type):
|
||||||
console.error(
|
console.error(
|
||||||
f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
|
f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
|
||||||
f" type '{self._var_type}', got '{type(value)}'."
|
f" type '{self._var_type}', got '{type(value)}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _deps(
|
def _deps(
|
||||||
self,
|
self,
|
||||||
objclass: Type,
|
objclass: Type[BaseState],
|
||||||
obj: FunctionType | CodeType | None = None,
|
obj: FunctionType | CodeType | None = None,
|
||||||
self_name: Optional[str] = None,
|
) -> dict[str, set[str]]:
|
||||||
) -> set[str]:
|
|
||||||
"""Determine var dependencies of this ComputedVar.
|
"""Determine var dependencies of this ComputedVar.
|
||||||
|
|
||||||
Save references to attributes accessed on "self". Recursively called
|
Save references to attributes accessed on "self" or other fetched states.
|
||||||
when the function makes a method call on "self" or define comprehensions
|
|
||||||
or nested functions that may reference "self".
|
Recursively called when the function makes a method call on "self" or
|
||||||
|
define comprehensions or nested functions that may reference "self".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
objclass: the class obj this ComputedVar is attached to.
|
objclass: the class obj this ComputedVar is attached to.
|
||||||
obj: the object to disassemble (defaults to the fget function).
|
obj: the object to disassemble (defaults to the fget function).
|
||||||
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A set of variable names accessed by the given obj.
|
A dictionary mapping state names to the set of variable names
|
||||||
|
accessed by the given obj.
|
||||||
Raises:
|
|
||||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
|
||||||
(cannot track deps in a related state, only implicitly via parent state).
|
|
||||||
"""
|
"""
|
||||||
|
from .dep_tracking import DependencyTracker
|
||||||
|
|
||||||
|
d = {}
|
||||||
|
if self._static_deps:
|
||||||
|
d.update(self._static_deps)
|
||||||
|
# None is a placeholder for the current state class.
|
||||||
|
if None in d:
|
||||||
|
d[objclass.get_full_name()] = d.pop(None)
|
||||||
|
|
||||||
if not self._auto_deps:
|
if not self._auto_deps:
|
||||||
return self._static_deps
|
return d
|
||||||
d = self._static_deps.copy()
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
fget = self._fget
|
fget = self._fget
|
||||||
if fget is not None:
|
if fget is not None:
|
||||||
obj = cast(FunctionType, fget)
|
obj = cast(FunctionType, fget)
|
||||||
else:
|
else:
|
||||||
return set()
|
return d
|
||||||
with contextlib.suppress(AttributeError):
|
|
||||||
# unbox functools.partial
|
|
||||||
obj = cast(FunctionType, obj.func) # pyright: ignore [reportAttributeAccessIssue]
|
|
||||||
with contextlib.suppress(AttributeError):
|
|
||||||
# unbox EventHandler
|
|
||||||
obj = cast(FunctionType, obj.fn) # pyright: ignore [reportAttributeAccessIssue]
|
|
||||||
|
|
||||||
if self_name is None and isinstance(obj, FunctionType):
|
try:
|
||||||
try:
|
return DependencyTracker(
|
||||||
# the first argument to the function is the name of "self" arg
|
func=obj, state_cls=objclass, dependencies=d
|
||||||
self_name = obj.__code__.co_varnames[0]
|
).dependencies
|
||||||
except (AttributeError, IndexError):
|
except Exception as e:
|
||||||
self_name = None
|
console.warn(
|
||||||
if self_name is None:
|
"Failed to automatically determine dependencies for computed var "
|
||||||
# cannot reference attributes on self if method takes no args
|
f"{objclass.__name__}.{self._js_expr}: {e}. "
|
||||||
return set()
|
"Provide static_deps and set auto_deps=False to suppress this warning."
|
||||||
|
)
|
||||||
invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
|
return d
|
||||||
self_is_top_of_stack = False
|
|
||||||
for instruction in dis.get_instructions(obj):
|
|
||||||
if (
|
|
||||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
|
||||||
and instruction.argval == self_name
|
|
||||||
):
|
|
||||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
|
||||||
# is referencing an attribute on self
|
|
||||||
self_is_top_of_stack = True
|
|
||||||
continue
|
|
||||||
if self_is_top_of_stack and instruction.opname in (
|
|
||||||
"LOAD_ATTR",
|
|
||||||
"LOAD_METHOD",
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
ref_obj = getattr(objclass, instruction.argval)
|
|
||||||
except Exception:
|
|
||||||
ref_obj = None
|
|
||||||
if instruction.argval in invalid_names:
|
|
||||||
raise VarValueError(
|
|
||||||
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
|
||||||
)
|
|
||||||
if callable(ref_obj):
|
|
||||||
# recurse into callable attributes
|
|
||||||
d.update(
|
|
||||||
self._deps(
|
|
||||||
objclass=objclass,
|
|
||||||
obj=ref_obj, # pyright: ignore [reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# recurse into property fget functions
|
|
||||||
elif isinstance(ref_obj, property) and not isinstance(
|
|
||||||
ref_obj, ComputedVar
|
|
||||||
):
|
|
||||||
d.update(
|
|
||||||
self._deps(
|
|
||||||
objclass=objclass,
|
|
||||||
obj=ref_obj.fget, # pyright: ignore [reportArgumentType]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
instruction.argval in objclass.backend_vars
|
|
||||||
or instruction.argval in objclass.vars
|
|
||||||
):
|
|
||||||
# var access
|
|
||||||
d.add(instruction.argval)
|
|
||||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
|
||||||
instruction.argval, CodeType
|
|
||||||
):
|
|
||||||
# recurse into nested functions / comprehensions, which can reference
|
|
||||||
# instance attributes from the outer scope
|
|
||||||
d.update(
|
|
||||||
self._deps(
|
|
||||||
objclass=objclass,
|
|
||||||
obj=instruction.argval,
|
|
||||||
self_name=self_name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self_is_top_of_stack = False
|
|
||||||
return d
|
|
||||||
|
|
||||||
def mark_dirty(self, instance: BaseState) -> None:
|
def mark_dirty(self, instance: BaseState) -> None:
|
||||||
"""Mark this ComputedVar as dirty.
|
"""Mark this ComputedVar as dirty.
|
||||||
@ -2362,6 +2323,37 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
delattr(instance, self._cache_attr)
|
delattr(instance, self._cache_attr)
|
||||||
|
|
||||||
|
def add_dependency(self, objclass: Type[BaseState], dep: Var):
|
||||||
|
"""Explicitly add a dependency to the ComputedVar.
|
||||||
|
|
||||||
|
After adding the dependency, when the `dep` changes, this computed var
|
||||||
|
will be marked dirty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
objclass: The class obj this ComputedVar is attached to.
|
||||||
|
dep: The dependency to add.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarDependencyError: If the dependency is not a Var instance with a
|
||||||
|
state and field name
|
||||||
|
"""
|
||||||
|
if all_var_data := dep._get_all_var_data():
|
||||||
|
state_name = all_var_data.state
|
||||||
|
if state_name:
|
||||||
|
var_name = all_var_data.field_name
|
||||||
|
if var_name:
|
||||||
|
self._static_deps.setdefault(state_name, set()).add(var_name)
|
||||||
|
objclass.get_root_state().get_class_substate(
|
||||||
|
state_name
|
||||||
|
)._var_dependencies.setdefault(var_name, set()).add(
|
||||||
|
(objclass.get_full_name(), self._js_expr)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
raise VarDependencyError(
|
||||||
|
"ComputedVar dependencies must be Var instances with a state and "
|
||||||
|
f"field name, got {dep!r}."
|
||||||
|
)
|
||||||
|
|
||||||
def _determine_var_type(self) -> Type:
|
def _determine_var_type(self) -> Type:
|
||||||
"""Get the type of the var.
|
"""Get the type of the var.
|
||||||
|
|
||||||
@ -2398,6 +2390,126 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _default_async_computed_var(_self: BaseState) -> Any:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(
|
||||||
|
eq=False,
|
||||||
|
frozen=True,
|
||||||
|
init=False,
|
||||||
|
slots=True,
|
||||||
|
)
|
||||||
|
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||||
|
"""A computed var that wraps a coroutinefunction."""
|
||||||
|
|
||||||
|
_fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
|
||||||
|
dataclasses.field(default=_default_async_computed_var)
|
||||||
|
)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[bool],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> BooleanVar: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[int] | ComputedVar[float],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> NumberVar: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[str],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> StringVar: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[list[LIST_INSIDE]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ArrayVar[list[LIST_INSIDE]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
|
||||||
|
instance: None,
|
||||||
|
owner: Type,
|
||||||
|
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self, instance: BaseState, owner: Type
|
||||||
|
) -> Coroutine[None, None, RETURN_TYPE]: ...
|
||||||
|
|
||||||
|
def __get__(
|
||||||
|
self, instance: BaseState | None, owner
|
||||||
|
) -> Var | Coroutine[None, None, RETURN_TYPE]:
|
||||||
|
"""Get the ComputedVar value.
|
||||||
|
|
||||||
|
If the value is already cached on the instance, return the cached value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: the instance of the class accessing this computed var.
|
||||||
|
owner: the class that this descriptor is attached to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the var for the given instance.
|
||||||
|
"""
|
||||||
|
if instance is None:
|
||||||
|
return super(AsyncComputedVar, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
if not self._cache:
|
||||||
|
|
||||||
|
async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
|
||||||
|
value = await self.fget(instance)
|
||||||
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return _awaitable_result()
|
||||||
|
else:
|
||||||
|
# handle caching
|
||||||
|
async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
|
||||||
|
if not hasattr(instance, self._cache_attr) or self.needs_update(
|
||||||
|
instance
|
||||||
|
):
|
||||||
|
# Set cache attr on state instance.
|
||||||
|
setattr(instance, self._cache_attr, await self.fget(instance))
|
||||||
|
# Ensure the computed var gets serialized to redis.
|
||||||
|
instance._was_touched = True
|
||||||
|
# Set the last updated timestamp on the state instance.
|
||||||
|
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
|
value = getattr(instance, self._cache_attr)
|
||||||
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return _awaitable_result()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
|
||||||
|
"""Get the getter function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The getter function.
|
||||||
|
"""
|
||||||
|
return self._fget
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
|
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
|
||||||
|
|
||||||
@ -2464,10 +2576,27 @@ def computed_var(
|
|||||||
raise VarDependencyError("Cannot track dependencies without caching.")
|
raise VarDependencyError("Cannot track dependencies without caching.")
|
||||||
|
|
||||||
if fget is not None:
|
if fget is not None:
|
||||||
return ComputedVar(fget, cache=cache)
|
if inspect.iscoroutinefunction(fget):
|
||||||
|
computed_var_cls = AsyncComputedVar
|
||||||
|
else:
|
||||||
|
computed_var_cls = ComputedVar
|
||||||
|
return computed_var_cls(
|
||||||
|
fget,
|
||||||
|
initial_value=initial_value,
|
||||||
|
cache=cache,
|
||||||
|
deps=deps,
|
||||||
|
auto_deps=auto_deps,
|
||||||
|
interval=interval,
|
||||||
|
backend=backend,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
|
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
|
||||||
return ComputedVar(
|
if inspect.iscoroutinefunction(fget):
|
||||||
|
computed_var_cls = AsyncComputedVar
|
||||||
|
else:
|
||||||
|
computed_var_cls = ComputedVar
|
||||||
|
return computed_var_cls(
|
||||||
fget,
|
fget,
|
||||||
initial_value=initial_value,
|
initial_value=initial_value,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
|
344
reflex/vars/dep_tracking.py
Normal file
344
reflex/vars/dep_tracking.py
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
"""Collection of base classes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import dataclasses
|
||||||
|
import dis
|
||||||
|
import enum
|
||||||
|
import inspect
|
||||||
|
from types import CellType, CodeType, FunctionType
|
||||||
|
from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
|
||||||
|
|
||||||
|
from reflex.utils.exceptions import VarValueError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from reflex.state import BaseState
|
||||||
|
|
||||||
|
from .base import Var
|
||||||
|
|
||||||
|
|
||||||
|
CellEmpty = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cell_value(cell: CellType) -> Any:
|
||||||
|
"""Get the value of a cell object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cell: The cell object to get the value from. (func.__closure__ objects)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value from the cell or CellEmpty if a ValueError is raised.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cell.cell_contents
|
||||||
|
except ValueError:
|
||||||
|
return CellEmpty
|
||||||
|
|
||||||
|
|
||||||
|
class ScanStatus(enum.Enum):
|
||||||
|
"""State of the dis instruction scanning loop."""
|
||||||
|
|
||||||
|
SCANNING = enum.auto()
|
||||||
|
GETTING_ATTR = enum.auto()
|
||||||
|
GETTING_STATE = enum.auto()
|
||||||
|
GETTING_VAR = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class DependencyTracker:
|
||||||
|
"""State machine for identifying state attributes that are accessed by a function."""
|
||||||
|
|
||||||
|
func: FunctionType | CodeType = dataclasses.field()
|
||||||
|
state_cls: Type[BaseState] = dataclasses.field()
|
||||||
|
|
||||||
|
dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
|
||||||
|
top_of_stack: str | None = dataclasses.field(default=None)
|
||||||
|
|
||||||
|
tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
_getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
|
||||||
|
_getting_var_instructions: list[dis.Instruction] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
|
INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""After initializing, populate the dependencies dict."""
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
|
# unbox functools.partial
|
||||||
|
self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
|
# unbox EventHandler
|
||||||
|
self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
|
||||||
|
if isinstance(self.func, FunctionType):
|
||||||
|
with contextlib.suppress(AttributeError, IndexError):
|
||||||
|
# the first argument to the function is the name of "self" arg
|
||||||
|
self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
|
||||||
|
|
||||||
|
self._populate_dependencies()
|
||||||
|
|
||||||
|
def _merge_deps(self, tracker: DependencyTracker) -> None:
|
||||||
|
"""Merge dependencies from another DependencyTracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tracker: The DependencyTracker to merge dependencies from.
|
||||||
|
"""
|
||||||
|
for state_name, dep_name in tracker.dependencies.items():
|
||||||
|
self.dependencies.setdefault(state_name, set()).update(dep_name)
|
||||||
|
|
||||||
|
def load_attr_or_method(self, instruction: dis.Instruction) -> None:
|
||||||
|
"""Handle loading an attribute or method from the object on top of the stack.
|
||||||
|
|
||||||
|
This method directly tracks attributes and recursively merges
|
||||||
|
dependencies from analyzing the dependencies of any methods called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: The dis instruction to process.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarValueError: if the attribute is an disallowed name.
|
||||||
|
"""
|
||||||
|
from .base import ComputedVar
|
||||||
|
|
||||||
|
if instruction.argval in self.INVALID_NAMES:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
||||||
|
)
|
||||||
|
if instruction.argval == "get_state":
|
||||||
|
# Special case: arbitrary state access requested.
|
||||||
|
self.scan_status = ScanStatus.GETTING_STATE
|
||||||
|
return
|
||||||
|
if instruction.argval == "get_var_value":
|
||||||
|
# Special case: arbitrary var access requested.
|
||||||
|
self.scan_status = ScanStatus.GETTING_VAR
|
||||||
|
return
|
||||||
|
|
||||||
|
# Reset status back to SCANNING after attribute is accessed.
|
||||||
|
self.scan_status = ScanStatus.SCANNING
|
||||||
|
if not self.top_of_stack:
|
||||||
|
return
|
||||||
|
target_state = self.tracked_locals[self.top_of_stack]
|
||||||
|
try:
|
||||||
|
ref_obj = getattr(target_state, instruction.argval)
|
||||||
|
except AttributeError:
|
||||||
|
# Not found on this state class, maybe it is a dynamic attribute that will be picked up later.
|
||||||
|
ref_obj = None
|
||||||
|
|
||||||
|
if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
|
||||||
|
# recurse into property fget functions
|
||||||
|
ref_obj = ref_obj.fget
|
||||||
|
if callable(ref_obj):
|
||||||
|
# recurse into callable attributes
|
||||||
|
self._merge_deps(
|
||||||
|
type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
instruction.argval in target_state.backend_vars
|
||||||
|
or instruction.argval in target_state.vars
|
||||||
|
):
|
||||||
|
# var access
|
||||||
|
self.dependencies.setdefault(target_state.get_full_name(), set()).add(
|
||||||
|
instruction.argval
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_globals(self) -> dict[str, Any]:
|
||||||
|
"""Get the globals of the function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The var names and values in the globals of the function.
|
||||||
|
"""
|
||||||
|
if isinstance(self.func, CodeType):
|
||||||
|
return {}
|
||||||
|
return self.func.__globals__ # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
|
||||||
|
def _get_closure(self) -> dict[str, Any]:
|
||||||
|
"""Get the closure of the function, with unbound values omitted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The var names and values in the closure of the function.
|
||||||
|
"""
|
||||||
|
if isinstance(self.func, CodeType):
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
var_name: get_cell_value(cell)
|
||||||
|
for var_name, cell in zip(
|
||||||
|
self.func.__code__.co_freevars, # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
self.func.__closure__ or (),
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
if get_cell_value(cell) is not CellEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
def handle_getting_state(self, instruction: dis.Instruction) -> None:
|
||||||
|
"""Handle bytecode analysis when `get_state` was called in the function.
|
||||||
|
|
||||||
|
If the wrapped function is getting an arbitrary state and saving it to a
|
||||||
|
local variable, this method associates the local variable name with the
|
||||||
|
state class in self.tracked_locals.
|
||||||
|
|
||||||
|
When an attribute/method is accessed on a tracked local, it will be
|
||||||
|
associated with this state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: The dis instruction to process.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarValueError: if the state class cannot be determined from the instruction.
|
||||||
|
"""
|
||||||
|
from reflex.state import BaseState
|
||||||
|
|
||||||
|
if instruction.opname == "LOAD_FAST":
|
||||||
|
raise VarValueError(
|
||||||
|
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
||||||
|
)
|
||||||
|
if isinstance(self.func, CodeType):
|
||||||
|
raise VarValueError(
|
||||||
|
"Dependency detection cannot identify get_state class from a code object."
|
||||||
|
)
|
||||||
|
if instruction.opname == "LOAD_GLOBAL":
|
||||||
|
# Special case: referencing state class from global scope.
|
||||||
|
try:
|
||||||
|
self._getting_state_class = self._get_globals()[instruction.argval]
|
||||||
|
except (ValueError, KeyError) as ve:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
|
||||||
|
) from ve
|
||||||
|
elif instruction.opname == "LOAD_DEREF":
|
||||||
|
# Special case: referencing state class from closure.
|
||||||
|
try:
|
||||||
|
self._getting_state_class = self._get_closure()[instruction.argval]
|
||||||
|
except (ValueError, KeyError) as ve:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
|
||||||
|
) from ve
|
||||||
|
elif instruction.opname == "STORE_FAST":
|
||||||
|
# Storing the result of get_state in a local variable.
|
||||||
|
if not isinstance(self._getting_state_class, type) or not issubclass(
|
||||||
|
self._getting_state_class, BaseState
|
||||||
|
):
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
||||||
|
)
|
||||||
|
self.tracked_locals[instruction.argval] = self._getting_state_class
|
||||||
|
self.scan_status = ScanStatus.SCANNING
|
||||||
|
self._getting_state_class = None
|
||||||
|
|
||||||
|
def _eval_var(self) -> Var:
|
||||||
|
"""Evaluate instructions from the wrapped function to get the Var object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The Var object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarValueError: if the source code for the var cannot be determined.
|
||||||
|
"""
|
||||||
|
# Get the original source code and eval it to get the Var.
|
||||||
|
module = inspect.getmodule(self.func)
|
||||||
|
positions0 = self._getting_var_instructions[0].positions
|
||||||
|
positions1 = self._getting_var_instructions[-1].positions
|
||||||
|
if module is None or positions0 is None or positions1 is None:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cannot determine the source code for the var in {self.func!r}."
|
||||||
|
)
|
||||||
|
start_line = positions0.lineno
|
||||||
|
start_column = positions0.col_offset
|
||||||
|
end_line = positions1.end_lineno
|
||||||
|
end_column = positions1.end_col_offset
|
||||||
|
if (
|
||||||
|
start_line is None
|
||||||
|
or start_column is None
|
||||||
|
or end_line is None
|
||||||
|
or end_column is None
|
||||||
|
):
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cannot determine the source code for the var in {self.func!r}."
|
||||||
|
)
|
||||||
|
source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
|
||||||
|
# Create a python source string snippet.
|
||||||
|
if len(source) > 1:
|
||||||
|
snipped_source = "".join(
|
||||||
|
[
|
||||||
|
*source[0][start_column:],
|
||||||
|
*(source[1:-2] if len(source) > 2 else []),
|
||||||
|
*source[-1][: end_column - 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
snipped_source = source[0][start_column : end_column - 1]
|
||||||
|
# Evaluate the string in the context of the function's globals and closure.
|
||||||
|
return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
|
||||||
|
|
||||||
|
def handle_getting_var(self, instruction: dis.Instruction) -> None:
|
||||||
|
"""Handle bytecode analysis when `get_var_value` was called in the function.
|
||||||
|
|
||||||
|
This only really works if the expression passed to `get_var_value` is
|
||||||
|
evaluable in the function's global scope or closure, so getting the var
|
||||||
|
value from a var saved in a local variable or in the class instance is
|
||||||
|
not possible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instruction: The dis instruction to process.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
VarValueError: if the source code for the var cannot be determined.
|
||||||
|
"""
|
||||||
|
if instruction.opname == "CALL" and self._getting_var_instructions:
|
||||||
|
if self._getting_var_instructions:
|
||||||
|
the_var = self._eval_var()
|
||||||
|
the_var_data = the_var._get_all_var_data()
|
||||||
|
if the_var_data is None:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cannot determine the source code for the var in {self.func!r}."
|
||||||
|
)
|
||||||
|
self.dependencies.setdefault(the_var_data.state, set()).add(
|
||||||
|
the_var_data.field_name
|
||||||
|
)
|
||||||
|
self._getting_var_instructions.clear()
|
||||||
|
self.scan_status = ScanStatus.SCANNING
|
||||||
|
else:
|
||||||
|
self._getting_var_instructions.append(instruction)
|
||||||
|
|
||||||
|
def _populate_dependencies(self) -> None:
|
||||||
|
"""Update self.dependencies based on the disassembly of self.func.
|
||||||
|
|
||||||
|
Save references to attributes accessed on "self" or other fetched states.
|
||||||
|
|
||||||
|
Recursively called when the function makes a method call on "self" or
|
||||||
|
define comprehensions or nested functions that may reference "self".
|
||||||
|
"""
|
||||||
|
for instruction in dis.get_instructions(self.func):
|
||||||
|
if self.scan_status == ScanStatus.GETTING_STATE:
|
||||||
|
self.handle_getting_state(instruction)
|
||||||
|
elif self.scan_status == ScanStatus.GETTING_VAR:
|
||||||
|
self.handle_getting_var(instruction)
|
||||||
|
elif (
|
||||||
|
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||||
|
and instruction.argval in self.tracked_locals
|
||||||
|
):
|
||||||
|
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||||
|
# is referencing an attribute on self
|
||||||
|
self.top_of_stack = instruction.argval
|
||||||
|
self.scan_status = ScanStatus.GETTING_ATTR
|
||||||
|
elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
|
||||||
|
"LOAD_ATTR",
|
||||||
|
"LOAD_METHOD",
|
||||||
|
):
|
||||||
|
self.load_attr_or_method(instruction)
|
||||||
|
self.top_of_stack = None
|
||||||
|
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||||
|
instruction.argval, CodeType
|
||||||
|
):
|
||||||
|
# recurse into nested functions / comprehensions, which can reference
|
||||||
|
# instance attributes from the outer scope
|
||||||
|
self._merge_deps(
|
||||||
|
type(self)(
|
||||||
|
func=instruction.argval,
|
||||||
|
state_cls=self.state_cls,
|
||||||
|
tracked_locals=self.tracked_locals,
|
||||||
|
)
|
||||||
|
)
|
@ -3,7 +3,7 @@
|
|||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from playwright.sync_api import Page
|
from playwright.sync_api import Page, expect
|
||||||
|
|
||||||
from reflex.testing import AppHarness
|
from reflex.testing import AppHarness
|
||||||
|
|
||||||
@ -87,12 +87,14 @@ def test_table(page: Page, table_app: AppHarness):
|
|||||||
table = page.get_by_role("table")
|
table = page.get_by_role("table")
|
||||||
|
|
||||||
# Check column headers
|
# Check column headers
|
||||||
headers = table.get_by_role("columnheader").all_inner_texts()
|
headers = table.get_by_role("columnheader")
|
||||||
assert headers == expected_col_headers
|
for header, exp_value in zip(headers.all(), expected_col_headers, strict=True):
|
||||||
|
expect(header).to_have_text(exp_value)
|
||||||
|
|
||||||
# Check rows headers
|
# Check rows headers
|
||||||
rows = table.get_by_role("rowheader").all_inner_texts()
|
rows = table.get_by_role("rowheader")
|
||||||
assert rows == expected_row_headers
|
for row, expected_row in zip(rows.all(), expected_row_headers, strict=True):
|
||||||
|
expect(row).to_have_text(expected_row)
|
||||||
|
|
||||||
# Check cells
|
# Check cells
|
||||||
rows = table.get_by_role("cell").all_inner_texts()
|
rows = table.get_by_role("cell").all_inner_texts()
|
||||||
|
@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
|||||||
assert app._pages.keys() == {"test/[dynamic]"}
|
assert app._pages.keys() == {"test/[dynamic]"}
|
||||||
assert "dynamic" in app._state.computed_vars
|
assert "dynamic" in app._state.computed_vars
|
||||||
assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
||||||
constants.ROUTER
|
EmptyState.get_full_name(): {constants.ROUTER},
|
||||||
}
|
}
|
||||||
assert constants.ROUTER in app._state()._computed_var_dependencies
|
assert constants.ROUTER in app._state()._var_dependencies
|
||||||
|
|
||||||
|
|
||||||
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
|
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
|
||||||
@ -995,9 +995,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
assert arg_name in app._state.vars
|
assert arg_name in app._state.vars
|
||||||
assert arg_name in app._state.computed_vars
|
assert arg_name in app._state.computed_vars
|
||||||
assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
|
assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
|
||||||
constants.ROUTER
|
DynamicState.get_full_name(): {constants.ROUTER},
|
||||||
}
|
}
|
||||||
assert constants.ROUTER in app._state()._computed_var_dependencies
|
assert constants.ROUTER in app._state()._var_dependencies
|
||||||
|
|
||||||
substate_token = _substate_key(token, DynamicState)
|
substate_token = _substate_key(token, DynamicState)
|
||||||
sid = "mock_sid"
|
sid = "mock_sid"
|
||||||
@ -1555,6 +1555,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
|
|||||||
def bar(self) -> str:
|
def bar(self) -> str:
|
||||||
return "bar"
|
return "bar"
|
||||||
|
|
||||||
|
class Child1(ValidDepState):
|
||||||
|
@computed_var(deps=["base", ValidDepState.bar])
|
||||||
|
def other(self) -> str:
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
class Child2(ValidDepState):
|
||||||
|
@computed_var(deps=["base", Child1.other])
|
||||||
|
def other(self) -> str:
|
||||||
|
return "other"
|
||||||
|
|
||||||
app._state = ValidDepState
|
app._state = ValidDepState
|
||||||
app._compile()
|
app._compile()
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -1169,13 +1170,17 @@ def test_conditional_computed_vars():
|
|||||||
|
|
||||||
ms = MainState()
|
ms = MainState()
|
||||||
# Initially there are no dirty computed vars.
|
# Initially there are no dirty computed vars.
|
||||||
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"flag"}) == {
|
||||||
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
|
(MainState.get_full_name(), "rendered_var")
|
||||||
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
|
}
|
||||||
|
assert ms._dirty_computed_vars(from_vars={"t2"}) == {
|
||||||
|
(MainState.get_full_name(), "rendered_var")
|
||||||
|
}
|
||||||
|
assert ms._dirty_computed_vars(from_vars={"t1"}) == {
|
||||||
|
(MainState.get_full_name(), "rendered_var")
|
||||||
|
}
|
||||||
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
|
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
|
||||||
"flag",
|
MainState.get_full_name(): {"flag", "t1", "t2"}
|
||||||
"t1",
|
|
||||||
"t2",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1370,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
|||||||
assert isinstance(HandlerState.handler, EventHandler)
|
assert isinstance(HandlerState.handler, EventHandler)
|
||||||
|
|
||||||
s = HandlerState()
|
s = HandlerState()
|
||||||
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
|
assert (
|
||||||
|
HandlerState.get_full_name(),
|
||||||
|
"cached_x_side_effect",
|
||||||
|
) in s._var_dependencies["x"]
|
||||||
assert s.cached_x_side_effect == 1
|
assert s.cached_x_side_effect == 1
|
||||||
assert s.x == 43
|
assert s.x == 43
|
||||||
s.handler()
|
s.handler()
|
||||||
@ -1460,15 +1468,15 @@ def test_computed_var_dependencies():
|
|||||||
return [z in self._z for z in range(5)]
|
return [z in self._z for z in range(5)]
|
||||||
|
|
||||||
cs = ComputedState()
|
cs = ComputedState()
|
||||||
assert cs._computed_var_dependencies["v"] == {
|
assert cs._var_dependencies["v"] == {
|
||||||
"comp_v",
|
(ComputedState.get_full_name(), "comp_v"),
|
||||||
"comp_v_backend",
|
(ComputedState.get_full_name(), "comp_v_backend"),
|
||||||
"comp_v_via_property",
|
(ComputedState.get_full_name(), "comp_v_via_property"),
|
||||||
}
|
}
|
||||||
assert cs._computed_var_dependencies["w"] == {"comp_w"}
|
assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
|
||||||
assert cs._computed_var_dependencies["x"] == {"comp_x"}
|
assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
|
||||||
assert cs._computed_var_dependencies["y"] == {"comp_y"}
|
assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
|
||||||
assert cs._computed_var_dependencies["_z"] == {"comp_z"}
|
assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
|
||||||
|
|
||||||
|
|
||||||
def test_backend_method():
|
def test_backend_method():
|
||||||
@ -3180,7 +3188,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
|
|||||||
RxState = State
|
RxState = State
|
||||||
|
|
||||||
|
|
||||||
def test_potentially_dirty_substates():
|
def test_potentially_dirty_states():
|
||||||
"""Test that potentially_dirty_substates returns the correct substates.
|
"""Test that potentially_dirty_substates returns the correct substates.
|
||||||
|
|
||||||
Even if the name "State" is shadowed, it should still work correctly.
|
Even if the name "State" is shadowed, it should still work correctly.
|
||||||
@ -3196,13 +3204,19 @@ def test_potentially_dirty_substates():
|
|||||||
def bar(self) -> str:
|
def bar(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
assert RxState._potentially_dirty_substates() == set()
|
assert RxState._get_potentially_dirty_states() == set()
|
||||||
assert State._potentially_dirty_substates() == set()
|
assert State._get_potentially_dirty_states() == set()
|
||||||
assert C1._potentially_dirty_substates() == set()
|
assert C1._get_potentially_dirty_states() == set()
|
||||||
|
|
||||||
|
|
||||||
def test_router_var_dep() -> None:
|
@pytest.mark.asyncio
|
||||||
"""Test that router var dependencies are correctly tracked."""
|
async def test_router_var_dep(state_manager: StateManager, token: str) -> None:
|
||||||
|
"""Test that router var dependencies are correctly tracked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager: A state manager.
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
|
||||||
class RouterVarParentState(State):
|
class RouterVarParentState(State):
|
||||||
"""A parent state for testing router var dependency."""
|
"""A parent state for testing router var dependency."""
|
||||||
@ -3219,30 +3233,27 @@ def test_router_var_dep() -> None:
|
|||||||
foo = RouterVarDepState.computed_vars["foo"]
|
foo = RouterVarDepState.computed_vars["foo"]
|
||||||
State._init_var_dependency_dicts()
|
State._init_var_dependency_dicts()
|
||||||
|
|
||||||
assert foo._deps(objclass=RouterVarDepState) == {"router"}
|
assert foo._deps(objclass=RouterVarDepState) == {
|
||||||
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
|
RouterVarDepState.get_full_name(): {"router"}
|
||||||
assert RouterVarParentState._substate_var_dependencies == {
|
|
||||||
"router": {RouterVarDepState.get_name()}
|
|
||||||
}
|
|
||||||
assert RouterVarDepState._computed_var_dependencies == {
|
|
||||||
"router": {"foo"},
|
|
||||||
}
|
}
|
||||||
|
assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[
|
||||||
|
"router"
|
||||||
|
]
|
||||||
|
|
||||||
rx_state = State()
|
# Get state from state manager.
|
||||||
parent_state = RouterVarParentState()
|
state_manager.state = State
|
||||||
state = RouterVarDepState()
|
rx_state = await state_manager.get_state(_substate_key(token, State))
|
||||||
|
assert RouterVarParentState.get_name() in rx_state.substates
|
||||||
# link states
|
parent_state = rx_state.substates[RouterVarParentState.get_name()]
|
||||||
rx_state.substates = {RouterVarParentState.get_name(): parent_state}
|
assert RouterVarDepState.get_name() in parent_state.substates
|
||||||
parent_state.parent_state = rx_state
|
state = parent_state.substates[RouterVarDepState.get_name()]
|
||||||
state.parent_state = parent_state
|
|
||||||
parent_state.substates = {RouterVarDepState.get_name(): state}
|
|
||||||
|
|
||||||
assert state.dirty_vars == set()
|
assert state.dirty_vars == set()
|
||||||
|
|
||||||
# Reassign router var
|
# Reassign router var
|
||||||
state.router = state.router
|
state.router = state.router
|
||||||
assert state.dirty_vars == {"foo", "router"}
|
assert rx_state.dirty_vars == {"router"}
|
||||||
|
assert state.dirty_vars == {"foo"}
|
||||||
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
|
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
|
||||||
|
|
||||||
|
|
||||||
@ -3801,3 +3812,128 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
|
|||||||
# Generic Var with no state
|
# Generic Var with no state
|
||||||
with pytest.raises(UnretrievableVarValueError):
|
with pytest.raises(UnretrievableVarValueError):
|
||||||
await state.get_var_value(rx.Var("undefined"))
|
await state.get_var_value(rx.Var("undefined"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
|
||||||
|
"""A test where an async computed var depends on a var in another state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Parent(BaseState):
|
||||||
|
"""A root state like rx.State."""
|
||||||
|
|
||||||
|
parent_var: int = 0
|
||||||
|
|
||||||
|
class Child2(Parent):
|
||||||
|
"""An unconnected child state."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Child3(Parent):
|
||||||
|
"""A child state with a computed var causing it to be pre-fetched.
|
||||||
|
|
||||||
|
If child3_var gets set to a value, and `get_state` erroneously
|
||||||
|
re-fetches it from redis, the value will be lost.
|
||||||
|
"""
|
||||||
|
|
||||||
|
child3_var: int = 0
|
||||||
|
|
||||||
|
@rx.var(cache=True)
|
||||||
|
def v(self) -> int:
|
||||||
|
return self.child3_var
|
||||||
|
|
||||||
|
class Child(Parent):
|
||||||
|
"""A state simulating UpdateVarsInternalState."""
|
||||||
|
|
||||||
|
@rx.var(cache=True)
|
||||||
|
async def v(self) -> int:
|
||||||
|
p = await self.get_state(Parent)
|
||||||
|
child3 = await self.get_state(Child3)
|
||||||
|
return child3.child3_var + p.parent_var
|
||||||
|
|
||||||
|
mock_app.state_manager.state = mock_app._state = Parent
|
||||||
|
|
||||||
|
# Get the top level state via unconnected sibling.
|
||||||
|
root = await mock_app.state_manager.get_state(_substate_key(token, Child))
|
||||||
|
# Set value in parent_var to assert it does not get refetched later.
|
||||||
|
root.parent_var = 1
|
||||||
|
|
||||||
|
if isinstance(mock_app.state_manager, StateManagerRedis):
|
||||||
|
# When redis is used, only states with uncached computed vars are pre-fetched.
|
||||||
|
assert Child2.get_name() not in root.substates
|
||||||
|
assert Child3.get_name() not in root.substates
|
||||||
|
|
||||||
|
# Get the unconnected sibling state, which will be used to `get_state` other instances.
|
||||||
|
child = root.get_substate(Child.get_full_name().split("."))
|
||||||
|
|
||||||
|
# Get an uncached child state.
|
||||||
|
child2 = await child.get_state(Child2)
|
||||||
|
assert child2.parent_var == 1
|
||||||
|
|
||||||
|
# Set value on already-cached Child3 state (prefetched because it has a Computed Var).
|
||||||
|
child3 = await child.get_state(Child3)
|
||||||
|
child3.child3_var = 1
|
||||||
|
|
||||||
|
assert await child.v == 2
|
||||||
|
assert await child.v == 2
|
||||||
|
root.parent_var = 2
|
||||||
|
assert await child.v == 3
|
||||||
|
|
||||||
|
|
||||||
|
class Table(rx.ComponentState):
|
||||||
|
"""A table state."""
|
||||||
|
|
||||||
|
data: ClassVar[Var]
|
||||||
|
|
||||||
|
@rx.var(cache=True, auto_deps=False)
|
||||||
|
async def rows(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Computed var over the given rows.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The data rows.
|
||||||
|
"""
|
||||||
|
return await self.get_var_value(self.data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component(cls, data: Var) -> rx.Component:
|
||||||
|
"""Get the component for the table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data var.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The component.
|
||||||
|
"""
|
||||||
|
cls.data = data
|
||||||
|
cls.computed_vars["rows"].add_dependency(cls, data)
|
||||||
|
return rx.foreach(data, lambda d: rx.text(d.to_string()))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
|
||||||
|
"""A test where an async computed var depends on a var in another state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class OtherState(rx.State):
|
||||||
|
"""A state with a var."""
|
||||||
|
|
||||||
|
data: List[Dict[str, Any]] = [{"foo": "bar"}]
|
||||||
|
|
||||||
|
mock_app.state_manager.state = mock_app._state = rx.State
|
||||||
|
comp = Table.create(data=OtherState.data)
|
||||||
|
state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
|
||||||
|
other_state = await state.get_state(OtherState)
|
||||||
|
assert comp.State is not None
|
||||||
|
comp_state = await state.get_state(comp.State)
|
||||||
|
assert comp_state.dirty_vars == set()
|
||||||
|
|
||||||
|
other_state.data.append({"foo": "baz"})
|
||||||
|
assert "rows" in comp_state.dirty_vars
|
||||||
|
@ -1807,9 +1807,9 @@ def cv_fget(state: BaseState) -> int:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"deps,expected",
|
"deps,expected",
|
||||||
[
|
[
|
||||||
(["a"], {"a"}),
|
(["a"], {None: {"a"}}),
|
||||||
(["b"], {"b"}),
|
(["b"], {None: {"b"}}),
|
||||||
([ComputedVar(fget=cv_fget)], {"cv_fget"}),
|
([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
|
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
|
||||||
@ -1857,6 +1857,28 @@ def test_to_string_operation():
|
|||||||
assert single_var._var_type == Email
|
assert single_var._var_type == Email
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_computed_var():
|
||||||
|
side_effect_counter = 0
|
||||||
|
|
||||||
|
class AsyncComputedVarState(BaseState):
|
||||||
|
v: int = 1
|
||||||
|
|
||||||
|
@computed_var(cache=True)
|
||||||
|
async def async_computed_var(self) -> int:
|
||||||
|
nonlocal side_effect_counter
|
||||||
|
side_effect_counter += 1
|
||||||
|
return self.v + 1
|
||||||
|
|
||||||
|
my_state = AsyncComputedVarState()
|
||||||
|
assert await my_state.async_computed_var == 2
|
||||||
|
assert await my_state.async_computed_var == 2
|
||||||
|
my_state.v = 2
|
||||||
|
assert await my_state.async_computed_var == 3
|
||||||
|
assert await my_state.async_computed_var == 3
|
||||||
|
assert side_effect_counter == 2
|
||||||
|
|
||||||
|
|
||||||
def test_var_data_hooks():
|
def test_var_data_hooks():
|
||||||
var_data_str = VarData(hooks="what")
|
var_data_str = VarData(hooks="what")
|
||||||
var_data_list = VarData(hooks=["what"])
|
var_data_list = VarData(hooks=["what"])
|
||||||
|
Loading…
Reference in New Issue
Block a user