WiP
This commit is contained in:
parent
048416163d
commit
3d50c1b623
@ -833,11 +833,17 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
if not var._cache:
|
if not var._cache:
|
||||||
continue
|
continue
|
||||||
deps = var._deps(objclass=state)
|
deps = var._deps(objclass=state)
|
||||||
for dep in deps:
|
for state_name, dep_set in deps.items():
|
||||||
if dep not in state.vars and dep not in state.backend_vars:
|
state_cls = (
|
||||||
raise exceptions.VarDependencyError(
|
state.get_root_state().get_class_substate(state_name)
|
||||||
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
|
if state_name != state.get_full_name()
|
||||||
)
|
else state
|
||||||
|
)
|
||||||
|
for dep in dep_set:
|
||||||
|
if dep not in state_cls.vars and dep not in state_cls.backend_vars:
|
||||||
|
raise exceptions.VarDependencyError(
|
||||||
|
f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
|
||||||
|
)
|
||||||
|
|
||||||
for substate in state.class_subclasses:
|
for substate in state.class_subclasses:
|
||||||
self._validate_var_dependencies(substate)
|
self._validate_var_dependencies(substate)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Optional, Type, Union
|
from typing import Any, Callable, Dict, Optional, Type, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@ -29,7 +30,7 @@ from reflex.components.base import (
|
|||||||
)
|
)
|
||||||
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
||||||
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
|
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
|
||||||
from reflex.state import BaseState
|
from reflex.state import BaseState, _resolve_delta
|
||||||
from reflex.style import Style
|
from reflex.style import Style
|
||||||
from reflex.utils import console, format, imports, path_ops
|
from reflex.utils import console, format, imports, path_ops
|
||||||
from reflex.utils.imports import ImportVar, ParsedImportDict
|
from reflex.utils.imports import ImportVar, ParsedImportDict
|
||||||
@ -169,7 +170,7 @@ def compile_state(state: Type[BaseState]) -> dict:
|
|||||||
initial_state = state(_reflex_internal_init=True).dict(
|
initial_state = state(_reflex_internal_init=True).dict(
|
||||||
initial=True, include_computed=False
|
initial=True, include_computed=False
|
||||||
)
|
)
|
||||||
return initial_state
|
return asyncio.run(_resolve_delta(initial_state))
|
||||||
|
|
||||||
|
|
||||||
def _compile_client_storage_field(
|
def _compile_client_storage_field(
|
||||||
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.event import Event, get_hydrate_event
|
from reflex.event import Event, get_hydrate_event
|
||||||
from reflex.middleware.middleware import Middleware
|
from reflex.middleware.middleware import Middleware
|
||||||
from reflex.state import BaseState, StateUpdate
|
from reflex.state import BaseState, StateUpdate, _resolve_delta
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from reflex.app import App
|
from reflex.app import App
|
||||||
@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
|
|||||||
setattr(state, constants.CompileVars.IS_HYDRATED, False)
|
setattr(state, constants.CompileVars.IS_HYDRATED, False)
|
||||||
|
|
||||||
# Get the initial state.
|
# Get the initial state.
|
||||||
delta = state.dict()
|
delta = await _resolve_delta(state.dict())
|
||||||
# since a full dict was captured, clean any dirtiness
|
# since a full dict was captured, clean any dirtiness
|
||||||
state._clean()
|
state._clean()
|
||||||
|
|
||||||
|
433
reflex/state.py
433
reflex/state.py
@ -328,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_delta(delta: Delta) -> Delta:
|
||||||
|
"""Await all coroutines in the delta.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delta: The delta to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same delta dict with all coroutines resolved to their return value.
|
||||||
|
"""
|
||||||
|
tasks = {}
|
||||||
|
for state_name, state_delta in delta.items():
|
||||||
|
for var_name, value in state_delta.items():
|
||||||
|
if asyncio.iscoroutine(value):
|
||||||
|
tasks[state_name, var_name] = asyncio.create_task(value)
|
||||||
|
for (state_name, var_name), task in tasks.items():
|
||||||
|
delta[state_name][var_name] = await task
|
||||||
|
return delta
|
||||||
|
|
||||||
|
|
||||||
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||||
"""The state of the app."""
|
"""The state of the app."""
|
||||||
|
|
||||||
@ -355,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# A set of subclassses of this class.
|
# A set of subclassses of this class.
|
||||||
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
|
||||||
|
|
||||||
# Mapping of var name to set of computed variables that depend on it
|
# Mapping of var name to set of (state_full_name, var_name) that depend on it.
|
||||||
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
_var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
|
||||||
|
|
||||||
# Mapping of var name to set of substates that depend on it
|
|
||||||
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
||||||
|
|
||||||
# Set of vars which always need to be recomputed
|
# Set of vars which always need to be recomputed
|
||||||
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
||||||
@ -367,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Set of substates which always need to be recomputed
|
# Set of substates which always need to be recomputed
|
||||||
_always_dirty_substates: ClassVar[Set[str]] = set()
|
_always_dirty_substates: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
|
# Set of states which might need to be recomputed if vars in this state change.
|
||||||
|
_potentially_dirty_states: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
# The parent state.
|
# The parent state.
|
||||||
parent_state: Optional[BaseState] = None
|
parent_state: Optional[BaseState] = None
|
||||||
|
|
||||||
@ -518,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
# Reset dirty substate tracking for this class.
|
# Reset dirty substate tracking for this class.
|
||||||
cls._always_dirty_substates = set()
|
cls._always_dirty_substates = set()
|
||||||
|
cls._potentially_dirty_states = set()
|
||||||
|
|
||||||
# Get the parent vars.
|
# Get the parent vars.
|
||||||
parent_state = cls.get_parent_state()
|
parent_state = cls.get_parent_state()
|
||||||
@ -621,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
setattr(cls, name, handler)
|
setattr(cls, name, handler)
|
||||||
|
|
||||||
# Initialize per-class var dependency tracking.
|
# Initialize per-class var dependency tracking.
|
||||||
cls._computed_var_dependencies = defaultdict(set)
|
cls._var_dependencies = {}
|
||||||
cls._substate_var_dependencies = defaultdict(set)
|
|
||||||
cls._init_var_dependency_dicts()
|
cls._init_var_dependency_dicts()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -767,26 +786,25 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Additional updates tracking dicts for vars and substates that always
|
Additional updates tracking dicts for vars and substates that always
|
||||||
need to be recomputed.
|
need to be recomputed.
|
||||||
"""
|
"""
|
||||||
inherited_vars = set(cls.inherited_vars).union(
|
|
||||||
set(cls.inherited_backend_vars),
|
|
||||||
)
|
|
||||||
for cvar_name, cvar in cls.computed_vars.items():
|
for cvar_name, cvar in cls.computed_vars.items():
|
||||||
# Add the dependencies.
|
if not cvar._cache:
|
||||||
for var in cvar._deps(objclass=cls):
|
# Do not perform dep calculation when cache=False (these are always dirty).
|
||||||
cls._computed_var_dependencies[var].add(cvar_name)
|
continue
|
||||||
if var in inherited_vars:
|
for state_name, dvar_set in cvar._deps(objclass=cls).items():
|
||||||
# track that this substate depends on its parent for this var
|
state_cls = cls.get_root_state().get_class_substate(state_name)
|
||||||
state_name = cls.get_name()
|
for dvar in dvar_set:
|
||||||
parent_state = cls.get_parent_state()
|
defining_state_cls = state_cls
|
||||||
while parent_state is not None and var in {
|
while dvar in {
|
||||||
**parent_state.vars,
|
*defining_state_cls.inherited_vars,
|
||||||
**parent_state.backend_vars,
|
*defining_state_cls.inherited_backend_vars,
|
||||||
}:
|
}:
|
||||||
parent_state._substate_var_dependencies[var].add(state_name)
|
defining_state_cls = defining_state_cls.get_parent_state()
|
||||||
state_name, parent_state = (
|
defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
|
||||||
parent_state.get_name(),
|
(cls.get_full_name(), cvar_name)
|
||||||
parent_state.get_parent_state(),
|
)
|
||||||
)
|
defining_state_cls._potentially_dirty_states.add(
|
||||||
|
cls.get_full_name()
|
||||||
|
)
|
||||||
|
|
||||||
# ComputedVar with cache=False always need to be recomputed
|
# ComputedVar with cache=False always need to be recomputed
|
||||||
cls._always_dirty_computed_vars = {
|
cls._always_dirty_computed_vars = {
|
||||||
@ -901,6 +919,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
raise ValueError(f"Only one parent state is allowed {parent_states}.")
|
raise ValueError(f"Only one parent state is allowed {parent_states}.")
|
||||||
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
|
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@functools.lru_cache()
|
||||||
|
def get_root_state(cls) -> Type[BaseState]:
|
||||||
|
"""Get the root state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The root state.
|
||||||
|
"""
|
||||||
|
parent_state = cls.get_parent_state()
|
||||||
|
return cls if parent_state is None else parent_state.get_root_state()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_substates(cls) -> set[Type[BaseState]]:
|
def get_substates(cls) -> set[Type[BaseState]]:
|
||||||
"""Get the substates of the state.
|
"""Get the substates of the state.
|
||||||
@ -1353,7 +1382,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
# Add the var to the dirty list.
|
# Add the var to the dirty list.
|
||||||
if name in self.vars or name in self._computed_var_dependencies:
|
if name in self.base_vars:
|
||||||
self.dirty_vars.add(name)
|
self.dirty_vars.add(name)
|
||||||
self._mark_dirty()
|
self._mark_dirty()
|
||||||
|
|
||||||
@ -1423,6 +1452,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
raise ValueError(f"Invalid path: {path}")
|
raise ValueError(f"Invalid path: {path}")
|
||||||
return self.substates[path[0]].get_substate(path[1:])
|
return self.substates[path[0]].get_substate(path[1:])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
|
||||||
|
"""Get substates which may have dirty vars due to dependencies.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The set of potentially dirty substate classes.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
cls.get_class_substate(substate_name)
|
||||||
|
for substate_name in cls._always_dirty_substates
|
||||||
|
}.union(
|
||||||
|
{
|
||||||
|
cls.get_root_state().get_class_substate(substate_name)
|
||||||
|
for substate_name in cls._potentially_dirty_states
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
|
||||||
"""Find the name of the nearest common ancestor shared by this and the other state.
|
"""Find the name of the nearest common ancestor shared by this and the other state.
|
||||||
@ -1493,55 +1539,37 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
parent_state = parent_state.parent_state
|
parent_state = parent_state.parent_state
|
||||||
return parent_state
|
return parent_state
|
||||||
|
|
||||||
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||||
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
"""Get a state instance from redis.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_state_cls: The class of the state to populate parent states for.
|
state_cls: The class of the state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The parent state instance of target_state_cls.
|
The instance of state_cls associated with this state's client_token.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If redis is not used in this backend process.
|
RuntimeError: If redis is not used in this backend process.
|
||||||
|
StateMismatchError: If the state instance is not of the expected type.
|
||||||
"""
|
"""
|
||||||
|
# Then get the target state and all its substates.
|
||||||
state_manager = get_state_manager()
|
state_manager = get_state_manager()
|
||||||
if not isinstance(state_manager, StateManagerRedis):
|
if not isinstance(state_manager, StateManagerRedis):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
|
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
||||||
"(All states should already be available -- this is likely a bug).",
|
"(All states should already be available -- this is likely a bug).",
|
||||||
)
|
)
|
||||||
|
state_in_redis = await state_manager._link_arbitrary_state(
|
||||||
|
self,
|
||||||
|
state_cls,
|
||||||
|
)
|
||||||
|
|
||||||
# Find the missing parent states up to the common ancestor.
|
if not isinstance(state_in_redis, state_cls):
|
||||||
(
|
raise StateMismatchError(
|
||||||
common_ancestor_name,
|
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
||||||
missing_parent_states,
|
|
||||||
) = self._determine_missing_parent_states(target_state_cls)
|
|
||||||
|
|
||||||
# Fetch all missing parent states and link them up to the common ancestor.
|
|
||||||
parent_states_tuple = self._get_parent_states()
|
|
||||||
root_state = parent_states_tuple[-1][1]
|
|
||||||
parent_states_by_name = dict(parent_states_tuple)
|
|
||||||
parent_state = parent_states_by_name[common_ancestor_name]
|
|
||||||
for parent_state_name in missing_parent_states:
|
|
||||||
try:
|
|
||||||
parent_state = root_state.get_substate(parent_state_name.split("."))
|
|
||||||
# The requested state is already cached, do NOT fetch it again.
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
# The requested state is missing, fetch from redis.
|
|
||||||
pass
|
|
||||||
parent_state = await state_manager.get_state(
|
|
||||||
token=_substate_key(
|
|
||||||
self.router.session.client_token, parent_state_name
|
|
||||||
),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=False,
|
|
||||||
parent_state=parent_state,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return the direct parent of target_state_cls for subsequent linking.
|
return state_in_redis
|
||||||
return parent_state
|
|
||||||
|
|
||||||
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||||
"""Get a state instance from the cache.
|
"""Get a state instance from the cache.
|
||||||
@ -1563,44 +1591,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
)
|
)
|
||||||
return substate
|
return substate
|
||||||
|
|
||||||
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
|
|
||||||
"""Get a state instance from redis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_cls: The class of the state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The instance of state_cls associated with this state's client_token.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If redis is not used in this backend process.
|
|
||||||
StateMismatchError: If the state instance is not of the expected type.
|
|
||||||
"""
|
|
||||||
# Fetch all missing parent states from redis.
|
|
||||||
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
|
|
||||||
|
|
||||||
# Then get the target state and all its substates.
|
|
||||||
state_manager = get_state_manager()
|
|
||||||
if not isinstance(state_manager, StateManagerRedis):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
|
|
||||||
"(All states should already be available -- this is likely a bug).",
|
|
||||||
)
|
|
||||||
|
|
||||||
state_in_redis = await state_manager.get_state(
|
|
||||||
token=_substate_key(self.router.session.client_token, state_cls),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=True,
|
|
||||||
parent_state=parent_state_of_state_cls,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(state_in_redis, state_cls):
|
|
||||||
raise StateMismatchError(
|
|
||||||
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
|
|
||||||
)
|
|
||||||
|
|
||||||
return state_in_redis
|
|
||||||
|
|
||||||
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
|
||||||
"""Get an instance of the state associated with this token.
|
"""Get an instance of the state associated with this token.
|
||||||
|
|
||||||
@ -1737,7 +1727,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _as_state_update(
|
async def _as_state_update(
|
||||||
self,
|
self,
|
||||||
handler: EventHandler,
|
handler: EventHandler,
|
||||||
events: EventSpec | list[EventSpec] | None,
|
events: EventSpec | list[EventSpec] | None,
|
||||||
@ -1765,7 +1755,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the delta after processing the event.
|
# Get the delta after processing the event.
|
||||||
delta = state.get_delta()
|
delta = await _resolve_delta(state.get_delta())
|
||||||
state._clean()
|
state._clean()
|
||||||
|
|
||||||
return StateUpdate(
|
return StateUpdate(
|
||||||
@ -1865,24 +1855,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Handle async generators.
|
# Handle async generators.
|
||||||
if inspect.isasyncgen(events):
|
if inspect.isasyncgen(events):
|
||||||
async for event in events:
|
async for event in events:
|
||||||
yield state._as_state_update(handler, event, final=False)
|
yield await state._as_state_update(handler, event, final=False)
|
||||||
yield state._as_state_update(handler, events=None, final=True)
|
yield await state._as_state_update(handler, events=None, final=True)
|
||||||
|
|
||||||
# Handle regular generators.
|
# Handle regular generators.
|
||||||
elif inspect.isgenerator(events):
|
elif inspect.isgenerator(events):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield state._as_state_update(handler, next(events), final=False)
|
yield await state._as_state_update(
|
||||||
|
handler, next(events), final=False
|
||||||
|
)
|
||||||
except StopIteration as si:
|
except StopIteration as si:
|
||||||
# the "return" value of the generator is not available
|
# the "return" value of the generator is not available
|
||||||
# in the loop, we must catch StopIteration to access it
|
# in the loop, we must catch StopIteration to access it
|
||||||
if si.value is not None:
|
if si.value is not None:
|
||||||
yield state._as_state_update(handler, si.value, final=False)
|
yield await state._as_state_update(
|
||||||
yield state._as_state_update(handler, events=None, final=True)
|
handler, si.value, final=False
|
||||||
|
)
|
||||||
|
yield await state._as_state_update(handler, events=None, final=True)
|
||||||
|
|
||||||
# Handle regular event chains.
|
# Handle regular event chains.
|
||||||
else:
|
else:
|
||||||
yield state._as_state_update(handler, events, final=True)
|
yield await state._as_state_update(handler, events, final=True)
|
||||||
|
|
||||||
# If an error occurs, throw a window alert.
|
# If an error occurs, throw a window alert.
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -1892,7 +1886,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
|
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield state._as_state_update(
|
yield await state._as_state_update(
|
||||||
handler,
|
handler,
|
||||||
event_specs,
|
event_specs,
|
||||||
final=True,
|
final=True,
|
||||||
@ -1900,15 +1894,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
def _mark_dirty_computed_vars(self) -> None:
|
def _mark_dirty_computed_vars(self) -> None:
|
||||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||||
|
# Append expired computed vars to dirty_vars to trigger recalculation
|
||||||
|
self.dirty_vars.update(self._expired_computed_vars())
|
||||||
|
# Append always dirty computed vars to dirty_vars to trigger recalculation
|
||||||
|
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||||
|
|
||||||
dirty_vars = self.dirty_vars
|
dirty_vars = self.dirty_vars
|
||||||
while dirty_vars:
|
while dirty_vars:
|
||||||
calc_vars, dirty_vars = dirty_vars, set()
|
calc_vars, dirty_vars = dirty_vars, set()
|
||||||
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||||
self.dirty_vars.add(cvar)
|
if state_name == self.get_full_name():
|
||||||
|
defining_state = self
|
||||||
|
else:
|
||||||
|
defining_state = self._get_root_state().get_substate(
|
||||||
|
tuple(state_name.split("."))
|
||||||
|
)
|
||||||
|
defining_state.dirty_vars.add(cvar)
|
||||||
dirty_vars.add(cvar)
|
dirty_vars.add(cvar)
|
||||||
actual_var = self.computed_vars.get(cvar)
|
actual_var = defining_state.computed_vars.get(cvar)
|
||||||
if actual_var is not None:
|
if actual_var is not None:
|
||||||
actual_var.mark_dirty(instance=self)
|
actual_var.mark_dirty(instance=defining_state)
|
||||||
|
if defining_state is not self:
|
||||||
|
defining_state._mark_dirty()
|
||||||
|
|
||||||
def _expired_computed_vars(self) -> set[str]:
|
def _expired_computed_vars(self) -> set[str]:
|
||||||
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
||||||
@ -1924,7 +1931,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
def _dirty_computed_vars(
|
def _dirty_computed_vars(
|
||||||
self, from_vars: set[str] | None = None, include_backend: bool = True
|
self, from_vars: set[str] | None = None, include_backend: bool = True
|
||||||
) -> set[str]:
|
) -> set[tuple[str, str]]:
|
||||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1935,32 +1942,59 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Set of computed vars to include in the delta.
|
Set of computed vars to include in the delta.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
cvar
|
(state_name, cvar)
|
||||||
for dirty_var in from_vars or self.dirty_vars
|
for dirty_var in from_vars or self.dirty_vars
|
||||||
for cvar in self._computed_var_dependencies[dirty_var]
|
for state_name, cvar in self._var_dependencies.get(dirty_var, set())
|
||||||
if include_backend or not self.computed_vars[cvar]._backend
|
if include_backend or not self.computed_vars[cvar]._backend
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
async def _recursively_populate_dependent_substates(
|
||||||
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
self,
|
||||||
"""Determine substates which could be affected by dirty vars in this state.
|
seen_classes: set[type[BaseState]] | None = None,
|
||||||
|
) -> set[type[BaseState]]:
|
||||||
|
"""Fetch all substates that have computed var dependencies on this state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seen_classes: set of classes that have already been seen to prevent infinite recursion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Set of State classes that may need to be fetched to recalc computed vars.
|
The set of classes that were processed (mostly for testability).
|
||||||
"""
|
"""
|
||||||
# _always_dirty_substates need to be fetched to recalc computed vars.
|
if seen_classes is None:
|
||||||
fetch_substates = {
|
print(
|
||||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
f"\n\nTop-level _recursively_populate_dependent_substates from {type(self)}:"
|
||||||
for substate_name in cls._always_dirty_substates
|
|
||||||
}
|
|
||||||
for dependent_substates in cls._substate_var_dependencies.values():
|
|
||||||
fetch_substates.update(
|
|
||||||
{
|
|
||||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
|
||||||
for substate_name in dependent_substates
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
return fetch_substates
|
seen_classes = set()
|
||||||
|
if type(self) in seen_classes:
|
||||||
|
return seen_classes
|
||||||
|
seen_classes.add(type(self))
|
||||||
|
populated_substate_instances = {}
|
||||||
|
for substate_cls in {
|
||||||
|
self.get_class_substate((self.get_name(), *substate_name.split(".")))
|
||||||
|
for substate_name in self._always_dirty_substates
|
||||||
|
}:
|
||||||
|
# _always_dirty_substates need to be fetched to recalc computed vars.
|
||||||
|
if substate_cls not in populated_substate_instances:
|
||||||
|
print(f"fetching always dirty {substate_cls}")
|
||||||
|
populated_substate_instances[substate_cls] = await self.get_state(
|
||||||
|
substate_cls
|
||||||
|
)
|
||||||
|
for dep_set in self._var_dependencies.values():
|
||||||
|
for substate_name, _ in dep_set:
|
||||||
|
if substate_name == self.get_full_name():
|
||||||
|
# Do NOT fetch our own state instance.
|
||||||
|
continue
|
||||||
|
substate_cls = self.get_root_state().get_class_substate(substate_name)
|
||||||
|
if substate_cls not in populated_substate_instances:
|
||||||
|
print(f"fetching dependent {substate_cls}")
|
||||||
|
populated_substate_instances[substate_cls] = await self.get_state(
|
||||||
|
substate_cls
|
||||||
|
)
|
||||||
|
for substate in populated_substate_instances.values():
|
||||||
|
await substate._recursively_populate_dependent_substates(
|
||||||
|
seen_classes=seen_classes,
|
||||||
|
)
|
||||||
|
return seen_classes
|
||||||
|
|
||||||
def get_delta(self) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
@ -1970,21 +2004,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"""
|
"""
|
||||||
delta = {}
|
delta = {}
|
||||||
|
|
||||||
# Apply dirty variables down into substates
|
self._mark_dirty_computed_vars()
|
||||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
||||||
self._mark_dirty()
|
|
||||||
|
|
||||||
frontend_computed_vars: set[str] = {
|
frontend_computed_vars: set[str] = {
|
||||||
name for name, cv in self.computed_vars.items() if not cv._backend
|
name for name, cv in self.computed_vars.items() if not cv._backend
|
||||||
}
|
}
|
||||||
|
|
||||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||||
# and always dirty computed vars (cache=False)
|
# and always dirty computed vars (cache=False)
|
||||||
delta_vars = (
|
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||||
self.dirty_vars.intersection(self.base_vars)
|
self.dirty_vars.intersection(frontend_computed_vars)
|
||||||
.union(self.dirty_vars.intersection(frontend_computed_vars))
|
|
||||||
.union(self._dirty_computed_vars(include_backend=False))
|
|
||||||
.union(self._always_dirty_computed_vars)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
subdelta: Dict[str, Any] = {
|
subdelta: Dict[str, Any] = {
|
||||||
@ -2014,23 +2042,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
self.parent_state.dirty_substates.add(self.get_name())
|
self.parent_state.dirty_substates.add(self.get_name())
|
||||||
self.parent_state._mark_dirty()
|
self.parent_state._mark_dirty()
|
||||||
|
|
||||||
# Append expired computed vars to dirty_vars to trigger recalculation
|
|
||||||
self.dirty_vars.update(self._expired_computed_vars())
|
|
||||||
|
|
||||||
# have to mark computed vars dirty to allow access to newly computed
|
# have to mark computed vars dirty to allow access to newly computed
|
||||||
# values within the same ComputedVar function
|
# values within the same ComputedVar function
|
||||||
self._mark_dirty_computed_vars()
|
self._mark_dirty_computed_vars()
|
||||||
self._mark_dirty_substates()
|
|
||||||
|
|
||||||
def _mark_dirty_substates(self):
|
|
||||||
"""Propagate dirty var / computed var status into substates."""
|
|
||||||
substates = self.substates
|
|
||||||
for var in self.dirty_vars:
|
|
||||||
for substate_name in self._substate_var_dependencies[var]:
|
|
||||||
self.dirty_substates.add(substate_name)
|
|
||||||
substate = substates[substate_name]
|
|
||||||
substate.dirty_vars.add(var)
|
|
||||||
substate._mark_dirty()
|
|
||||||
|
|
||||||
def _update_was_touched(self):
|
def _update_was_touched(self):
|
||||||
"""Update the _was_touched flag based on dirty_vars."""
|
"""Update the _was_touched flag based on dirty_vars."""
|
||||||
@ -2102,11 +2116,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
The object as a dictionary.
|
The object as a dictionary.
|
||||||
"""
|
"""
|
||||||
if include_computed:
|
if include_computed:
|
||||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
self._mark_dirty_computed_vars()
|
||||||
# trigger recalculation of dependent vars
|
|
||||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
|
||||||
self._mark_dirty()
|
|
||||||
|
|
||||||
base_vars = {
|
base_vars = {
|
||||||
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
|
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
|
||||||
}
|
}
|
||||||
@ -3339,6 +3349,79 @@ class StateManagerRedis(StateManager):
|
|||||||
)
|
)
|
||||||
return parent_state
|
return parent_state
|
||||||
|
|
||||||
|
async def _populate_parent_states(
|
||||||
|
self, calling_state: BaseState, target_state_cls: Type[BaseState]
|
||||||
|
):
|
||||||
|
"""Populate substates in the tree between the target_state_cls and common ancestor of calling_state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calling_state: The substate instance requesting subtree population.
|
||||||
|
target_state_cls: The class of the state to populate parent states for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parent state instance of target_state_cls.
|
||||||
|
"""
|
||||||
|
# Find the missing parent states up to the common ancestor.
|
||||||
|
(
|
||||||
|
common_ancestor_name,
|
||||||
|
missing_parent_states,
|
||||||
|
) = calling_state._determine_missing_parent_states(target_state_cls)
|
||||||
|
|
||||||
|
# Fetch all missing parent states and link them up to the common ancestor.
|
||||||
|
parent_states_tuple = calling_state._get_parent_states()
|
||||||
|
root_state = parent_states_tuple[-1][1]
|
||||||
|
parent_states_by_name = dict(parent_states_tuple)
|
||||||
|
parent_state = parent_states_by_name[common_ancestor_name]
|
||||||
|
for parent_state_name in missing_parent_states:
|
||||||
|
try:
|
||||||
|
parent_state = root_state.get_substate(parent_state_name.split("."))
|
||||||
|
# The requested state is already cached, do NOT fetch it again.
|
||||||
|
continue
|
||||||
|
except ValueError:
|
||||||
|
# The requested state is missing, fetch from redis.
|
||||||
|
pass
|
||||||
|
parent_state = await self.get_state(
|
||||||
|
token=_substate_key(
|
||||||
|
calling_state.router.session.client_token, parent_state_name
|
||||||
|
),
|
||||||
|
top_level=False,
|
||||||
|
get_substates=False,
|
||||||
|
parent_state=parent_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the direct parent of target_state_cls for subsequent linking.
|
||||||
|
return parent_state
|
||||||
|
|
||||||
|
async def _link_arbitrary_state(
|
||||||
|
self, calling_state: BaseState, state_cls: Type[T_STATE]
|
||||||
|
) -> T_STATE:
|
||||||
|
"""Get a state instance from redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calling_state: The state instance requesting the newly linked instance of state_cls.
|
||||||
|
state_cls: The class of the state to link into the tree.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The instance of state_cls associated with calling_state's client_token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StateMismatchError: If the state instance is not of the expected type.
|
||||||
|
"""
|
||||||
|
# Fetch all missing parent states from redis.
|
||||||
|
parent_state_of_state_cls = await self._populate_parent_states(
|
||||||
|
calling_state, state_cls
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then get the target state and all its substates.
|
||||||
|
state_in_redis = await self.get_state(
|
||||||
|
token=_substate_key(calling_state.router.session.client_token, state_cls),
|
||||||
|
top_level=False,
|
||||||
|
get_substates=True,
|
||||||
|
parent_state=parent_state_of_state_cls,
|
||||||
|
)
|
||||||
|
|
||||||
|
return state_in_redis
|
||||||
|
|
||||||
async def _populate_substates(
|
async def _populate_substates(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
@ -3357,30 +3440,40 @@ class StateManagerRedis(StateManager):
|
|||||||
"""
|
"""
|
||||||
client_token, _ = _split_substate_key(token)
|
client_token, _ = _split_substate_key(token)
|
||||||
|
|
||||||
|
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
||||||
|
fetch_substates = state._get_potentially_dirty_states()
|
||||||
if all_substates:
|
if all_substates:
|
||||||
# All substates are requested.
|
# All substates are requested.
|
||||||
fetch_substates = state.get_substates()
|
fetch_substates.update(state.get_substates())
|
||||||
else:
|
|
||||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
|
||||||
fetch_substates = state._potentially_dirty_substates()
|
|
||||||
|
|
||||||
tasks = {}
|
tasks = {}
|
||||||
|
link_tasks = set()
|
||||||
# Retrieve the necessary substates from redis.
|
# Retrieve the necessary substates from redis.
|
||||||
for substate_cls in fetch_substates:
|
for substate_cls in fetch_substates:
|
||||||
if substate_cls.get_name() in state.substates:
|
if substate_cls.get_name() in state.substates:
|
||||||
continue
|
continue
|
||||||
substate_name = substate_cls.get_name()
|
substate_name = substate_cls.get_name()
|
||||||
tasks[substate_name] = asyncio.create_task(
|
if substate_cls in state.get_substates():
|
||||||
self.get_state(
|
tasks[substate_name] = asyncio.create_task(
|
||||||
token=_substate_key(client_token, substate_cls),
|
self.get_state(
|
||||||
top_level=False,
|
token=_substate_key(client_token, substate_cls),
|
||||||
get_substates=all_substates,
|
top_level=False,
|
||||||
parent_state=state,
|
get_substates=all_substates,
|
||||||
|
parent_state=state,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
|
try:
|
||||||
|
state._get_root_state().get_substate(substate_name.split("."))
|
||||||
|
except ValueError:
|
||||||
|
# The requested state is missing, so fetch and link it (and its parents).
|
||||||
|
link_tasks.add(
|
||||||
|
asyncio.create_task(self._link_arbitrary_state(state, substate_cls))
|
||||||
|
)
|
||||||
|
|
||||||
for substate_name, substate_task in tasks.items():
|
for substate_name, substate_task in tasks.items():
|
||||||
state.substates[substate_name] = await substate_task
|
state.substates[substate_name] = await substate_task
|
||||||
|
await asyncio.gather(*link_tasks)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_state(
|
async def get_state(
|
||||||
@ -4153,7 +4246,7 @@ def reload_state_module(
|
|||||||
if subclass.__module__ == module and module is not None:
|
if subclass.__module__ == module and module is not None:
|
||||||
state.class_subclasses.remove(subclass)
|
state.class_subclasses.remove(subclass)
|
||||||
state._always_dirty_substates.discard(subclass.get_name())
|
state._always_dirty_substates.discard(subclass.get_name())
|
||||||
state._computed_var_dependencies = defaultdict(set)
|
state._potentially_dirty_substates.discard(subclass.get_name())
|
||||||
state._substate_var_dependencies = defaultdict(set)
|
state._var_dependencies = {}
|
||||||
state._init_var_dependency_dicts()
|
state._init_var_dependency_dicts()
|
||||||
state.get_class_substate.cache_clear()
|
state.get_class_substate.cache_clear()
|
||||||
|
@ -1826,7 +1826,7 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
|
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
|
||||||
|
|
||||||
# Explicit var dependencies to track
|
# Explicit var dependencies to track
|
||||||
_static_deps: set[str] = dataclasses.field(default_factory=set)
|
_static_deps: dict[str, set[str]] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
# Whether var dependencies should be auto-determined
|
# Whether var dependencies should be auto-determined
|
||||||
_auto_deps: bool = dataclasses.field(default=True)
|
_auto_deps: bool = dataclasses.field(default=True)
|
||||||
@ -1901,21 +1901,40 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
|
|
||||||
object.__setattr__(self, "_update_interval", interval)
|
object.__setattr__(self, "_update_interval", interval)
|
||||||
|
|
||||||
if deps is None:
|
_static_deps = {}
|
||||||
deps = []
|
if isinstance(deps, dict):
|
||||||
else:
|
# Assume a dict is coming from _replace, so no special processing.
|
||||||
|
_static_deps = deps
|
||||||
|
elif deps is not None:
|
||||||
for dep in deps:
|
for dep in deps:
|
||||||
if isinstance(dep, Var):
|
if isinstance(dep, Var):
|
||||||
continue
|
state_name = (
|
||||||
if isinstance(dep, str) and dep != "":
|
all_var_data.state
|
||||||
continue
|
if (all_var_data := dep._get_all_var_data())
|
||||||
raise TypeError(
|
and all_var_data.state
|
||||||
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
else None
|
||||||
)
|
)
|
||||||
|
var_name = (
|
||||||
|
dep._js_expr[len(formatted_state_prefix) :]
|
||||||
|
if state_name
|
||||||
|
and (
|
||||||
|
formatted_state_prefix := format_state_name(state_name)
|
||||||
|
+ "."
|
||||||
|
)
|
||||||
|
and dep._js_expr.startswith(formatted_state_prefix)
|
||||||
|
else dep._js_expr
|
||||||
|
)
|
||||||
|
_static_deps.setdefault(state_name, set()).add(var_name)
|
||||||
|
elif isinstance(dep, str) and dep != "":
|
||||||
|
_static_deps.setdefault(None, set()).add(dep)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
|
||||||
|
)
|
||||||
object.__setattr__(
|
object.__setattr__(
|
||||||
self,
|
self,
|
||||||
"_static_deps",
|
"_static_deps",
|
||||||
{dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
|
_static_deps,
|
||||||
)
|
)
|
||||||
object.__setattr__(self, "_auto_deps", auto_deps)
|
object.__setattr__(self, "_auto_deps", auto_deps)
|
||||||
|
|
||||||
@ -2081,6 +2100,11 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
value = getattr(instance, self._cache_attr)
|
value = getattr(instance, self._cache_attr)
|
||||||
|
|
||||||
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _check_deprecated_return_type(self, instance, value) -> None:
|
||||||
if not _isinstance(value, self._var_type):
|
if not _isinstance(value, self._var_type):
|
||||||
console.deprecate(
|
console.deprecate(
|
||||||
"mismatched-computed-var-return",
|
"mismatched-computed-var-return",
|
||||||
@ -2090,41 +2114,49 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
"0.7.0",
|
"0.7.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _deps(
|
def _deps(
|
||||||
self,
|
self,
|
||||||
objclass: Type,
|
objclass: BaseState,
|
||||||
obj: FunctionType | CodeType | None = None,
|
obj: FunctionType | CodeType | None = None,
|
||||||
self_name: Optional[str] = None,
|
self_names: Optional[dict[str, str]] = None,
|
||||||
) -> set[str]:
|
) -> dict[str, set[str]]:
|
||||||
"""Determine var dependencies of this ComputedVar.
|
"""Determine var dependencies of this ComputedVar.
|
||||||
|
|
||||||
Save references to attributes accessed on "self". Recursively called
|
Save references to attributes accessed on "self" or other fetched states.
|
||||||
when the function makes a method call on "self" or define comprehensions
|
|
||||||
or nested functions that may reference "self".
|
Recursively called when the function makes a method call on "self" or
|
||||||
|
define comprehensions or nested functions that may reference "self".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
objclass: the class obj this ComputedVar is attached to.
|
objclass: the class obj this ComputedVar is attached to.
|
||||||
obj: the object to disassemble (defaults to the fget function).
|
obj: the object to disassemble (defaults to the fget function).
|
||||||
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
|
self_names: if specified, look for these names in LOAD_FAST and LOAD_DEREF instructions.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A set of variable names accessed by the given obj.
|
A dictionary mapping state names to the set of variable names
|
||||||
|
accessed by the given obj.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
VarValueError: if the function references the get_state, parent_state, or substates attributes
|
||||||
(cannot track deps in a related state, only implicitly via parent state).
|
(cannot track deps in a related state, only implicitly via parent state).
|
||||||
"""
|
"""
|
||||||
|
from reflex.state import BaseState
|
||||||
|
|
||||||
|
d = {}
|
||||||
|
if self._static_deps:
|
||||||
|
d.update(self._static_deps)
|
||||||
|
# None is a placeholder for the current state class.
|
||||||
|
if None in d:
|
||||||
|
d[objclass.get_full_name()] = d.pop(None)
|
||||||
|
|
||||||
if not self._auto_deps:
|
if not self._auto_deps:
|
||||||
return self._static_deps
|
return d
|
||||||
d = self._static_deps.copy()
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
fget = self._fget
|
fget = self._fget
|
||||||
if fget is not None:
|
if fget is not None:
|
||||||
obj = cast(FunctionType, fget)
|
obj = cast(FunctionType, fget)
|
||||||
else:
|
else:
|
||||||
return set()
|
return d
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
# unbox functools.partial
|
# unbox functools.partial
|
||||||
obj = cast(FunctionType, obj.func) # type: ignore
|
obj = cast(FunctionType, obj.func) # type: ignore
|
||||||
@ -2132,76 +2164,150 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
# unbox EventHandler
|
# unbox EventHandler
|
||||||
obj = cast(FunctionType, obj.fn) # type: ignore
|
obj = cast(FunctionType, obj.fn) # type: ignore
|
||||||
|
|
||||||
if self_name is None and isinstance(obj, FunctionType):
|
if self_names is None and isinstance(obj, FunctionType):
|
||||||
try:
|
try:
|
||||||
# the first argument to the function is the name of "self" arg
|
# the first argument to the function is the name of "self" arg
|
||||||
self_name = obj.__code__.co_varnames[0]
|
self_names = {obj.__code__.co_varnames[0]: objclass.get_full_name()}
|
||||||
except (AttributeError, IndexError):
|
except (AttributeError, IndexError):
|
||||||
self_name = None
|
self_names = None
|
||||||
if self_name is None:
|
if self_names is None:
|
||||||
# cannot reference attributes on self if method takes no args
|
# cannot reference attributes on self if method takes no args
|
||||||
return set()
|
return d
|
||||||
|
|
||||||
invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
|
invalid_names = ["parent_state", "substates", "get_substate"]
|
||||||
self_is_top_of_stack = False
|
self_on_top_of_stack = None
|
||||||
|
getting_state = False
|
||||||
|
getting_var = False
|
||||||
for instruction in dis.get_instructions(obj):
|
for instruction in dis.get_instructions(obj):
|
||||||
|
if getting_state:
|
||||||
|
if instruction.opname == "LOAD_FAST":
|
||||||
|
raise VarValueError(
|
||||||
|
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
|
||||||
|
)
|
||||||
|
if instruction.opname == "LOAD_GLOBAL":
|
||||||
|
# Special case: referencing state class from global scope.
|
||||||
|
getting_state = obj.__globals__.get(instruction.argval)
|
||||||
|
elif instruction.opname == "LOAD_DEREF":
|
||||||
|
# Special case: referencing state class from closure.
|
||||||
|
closure = dict(zip(obj.__code__.co_freevars, obj.__closure__))
|
||||||
|
try:
|
||||||
|
getting_state = closure[instruction.argval].cell_contents
|
||||||
|
except ValueError as ve:
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
|
||||||
|
) from ve
|
||||||
|
elif instruction.opname == "STORE_FAST":
|
||||||
|
# Storing the result of get_state in a local variable.
|
||||||
|
if not isinstance(getting_state, type) or not issubclass(
|
||||||
|
getting_state, BaseState
|
||||||
|
):
|
||||||
|
raise VarValueError(
|
||||||
|
f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
|
||||||
|
)
|
||||||
|
self_names[instruction.argval] = getting_state.get_full_name()
|
||||||
|
getting_state = False
|
||||||
|
continue # nothing else happens until we have identified the local var
|
||||||
|
if getting_var:
|
||||||
|
if instruction.opname == "CALL":
|
||||||
|
# get the original source code and eval it
|
||||||
|
start_line = getting_var[0].positions.lineno
|
||||||
|
start_column = getting_var[0].positions.col_offset
|
||||||
|
end_line = getting_var[-1].positions.end_lineno
|
||||||
|
end_column = getting_var[-1].positions.end_col_offset
|
||||||
|
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line]
|
||||||
|
if len(source) > 1:
|
||||||
|
snipped_source = "".join(
|
||||||
|
[
|
||||||
|
source[0][start_column:],
|
||||||
|
source[1:-2] if len(source) > 2 else "",
|
||||||
|
source[-1][:end_column]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
snipped_source = source[0][start_column:end_column]
|
||||||
|
the_var = eval(f"({snipped_source})", obj.__globals__)
|
||||||
|
print(the_var)
|
||||||
|
# code = source[start_line - 1]
|
||||||
|
# bytecode = bytearray((dis.opmap["RESUME"], 0))
|
||||||
|
# for ins in getting_var:
|
||||||
|
# bytecode.append(ins.opcode)
|
||||||
|
# bytecode.append(ins.arg or 0 & 0xFF)
|
||||||
|
# bytecode.extend((dis.opmap["RETURN_VALUE"], 0))
|
||||||
|
# bc = dis.Bytecode(obj)
|
||||||
|
# code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=())
|
||||||
|
# breakpoint()
|
||||||
|
getting_var = False
|
||||||
|
elif isinstance(getting_var, list):
|
||||||
|
getting_var.append(instruction)
|
||||||
|
else:
|
||||||
|
getting_var = [instruction]
|
||||||
|
continue
|
||||||
if (
|
if (
|
||||||
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
|
||||||
and instruction.argval == self_name
|
and instruction.argval in self_names
|
||||||
):
|
):
|
||||||
# bytecode loaded the class instance to the top of stack, next load instruction
|
# bytecode loaded the class instance to the top of stack, next load instruction
|
||||||
# is referencing an attribute on self
|
# is referencing an attribute on self
|
||||||
self_is_top_of_stack = True
|
self_on_top_of_stack = self_names[instruction.argval]
|
||||||
continue
|
continue
|
||||||
if self_is_top_of_stack and instruction.opname in (
|
if self_on_top_of_stack and instruction.opname in (
|
||||||
"LOAD_ATTR",
|
"LOAD_ATTR",
|
||||||
"LOAD_METHOD",
|
"LOAD_METHOD",
|
||||||
):
|
):
|
||||||
try:
|
|
||||||
ref_obj = getattr(objclass, instruction.argval)
|
|
||||||
except Exception:
|
|
||||||
ref_obj = None
|
|
||||||
if instruction.argval in invalid_names:
|
if instruction.argval in invalid_names:
|
||||||
raise VarValueError(
|
raise VarValueError(
|
||||||
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
|
||||||
)
|
)
|
||||||
|
if instruction.argval == "get_state":
|
||||||
|
# Special case: arbitrary state access requested.
|
||||||
|
getting_state = True
|
||||||
|
continue
|
||||||
|
if instruction.argval == "get_var_value":
|
||||||
|
# Special case: arbitrary var access requested.
|
||||||
|
getting_var = True
|
||||||
|
continue
|
||||||
|
print(f"{self_on_top_of_stack=}")
|
||||||
|
target_state = objclass.get_root_state().get_class_substate(
|
||||||
|
self_on_top_of_stack
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
ref_obj = getattr(target_state, instruction.argval)
|
||||||
|
except Exception:
|
||||||
|
ref_obj = None
|
||||||
if callable(ref_obj):
|
if callable(ref_obj):
|
||||||
# recurse into callable attributes
|
# recurse into callable attributes
|
||||||
d.update(
|
for state_name, dep_name in self._deps(
|
||||||
self._deps(
|
objclass=target_state,
|
||||||
objclass=objclass,
|
obj=ref_obj,
|
||||||
obj=ref_obj,
|
).items():
|
||||||
)
|
d.setdefault(state_name, set()).update(dep_name)
|
||||||
)
|
|
||||||
# recurse into property fget functions
|
# recurse into property fget functions
|
||||||
elif isinstance(ref_obj, property) and not isinstance(
|
elif isinstance(ref_obj, property) and not isinstance(
|
||||||
ref_obj, ComputedVar
|
ref_obj, ComputedVar
|
||||||
):
|
):
|
||||||
d.update(
|
for state_name, dep_name in self._deps(
|
||||||
self._deps(
|
objclass=target_state,
|
||||||
objclass=objclass,
|
obj=ref_obj.fget, # type: ignore
|
||||||
obj=ref_obj.fget, # type: ignore
|
).items():
|
||||||
)
|
d.setdefault(state_name, set()).update(dep_name)
|
||||||
)
|
|
||||||
elif (
|
elif (
|
||||||
instruction.argval in objclass.backend_vars
|
instruction.argval in target_state.backend_vars
|
||||||
or instruction.argval in objclass.vars
|
or instruction.argval in target_state.vars
|
||||||
):
|
):
|
||||||
# var access
|
# var access
|
||||||
d.add(instruction.argval)
|
d.setdefault(self_on_top_of_stack, set()).add(instruction.argval)
|
||||||
elif instruction.opname == "LOAD_CONST" and isinstance(
|
elif instruction.opname == "LOAD_CONST" and isinstance(
|
||||||
instruction.argval, CodeType
|
instruction.argval, CodeType
|
||||||
):
|
):
|
||||||
# recurse into nested functions / comprehensions, which can reference
|
# recurse into nested functions / comprehensions, which can reference
|
||||||
# instance attributes from the outer scope
|
# instance attributes from the outer scope
|
||||||
d.update(
|
for state_name, dep_name in self._deps(
|
||||||
self._deps(
|
objclass=objclass,
|
||||||
objclass=objclass,
|
obj=instruction.argval,
|
||||||
obj=instruction.argval,
|
self_names=self_names,
|
||||||
self_name=self_name,
|
).items():
|
||||||
)
|
d.setdefault(state_name, set()).update(dep_name)
|
||||||
)
|
self_on_top_of_stack = None
|
||||||
self_is_top_of_stack = False
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def mark_dirty(self, instance) -> None:
|
def mark_dirty(self, instance) -> None:
|
||||||
@ -2249,6 +2355,60 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(
|
||||||
|
eq=False,
|
||||||
|
frozen=True,
|
||||||
|
init=False,
|
||||||
|
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||||
|
)
|
||||||
|
class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
|
||||||
|
"""A computed var that wraps a coroutinefunction."""
|
||||||
|
|
||||||
|
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
|
||||||
|
default_factory=lambda: lambda _: None
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
def __get__(self, instance: BaseState | None, owner):
|
||||||
|
"""Get the ComputedVar value.
|
||||||
|
|
||||||
|
If the value is already cached on the instance, return the cached value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: the instance of the class accessing this computed var.
|
||||||
|
owner: the class that this descriptor is attached to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the var for the given instance.
|
||||||
|
"""
|
||||||
|
if instance is None:
|
||||||
|
return super(AsyncComputedVar, self).__get__(instance, owner)
|
||||||
|
|
||||||
|
if not self._cache:
|
||||||
|
|
||||||
|
async def _awaitable_result():
|
||||||
|
value = await self.fget(instance)
|
||||||
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
|
||||||
|
return _awaitable_result()
|
||||||
|
else:
|
||||||
|
# handle caching
|
||||||
|
async def _awaitable_result():
|
||||||
|
if not hasattr(instance, self._cache_attr) or self.needs_update(
|
||||||
|
instance
|
||||||
|
):
|
||||||
|
# Set cache attr on state instance.
|
||||||
|
setattr(instance, self._cache_attr, await self.fget(instance))
|
||||||
|
# Ensure the computed var gets serialized to redis.
|
||||||
|
instance._was_touched = True
|
||||||
|
# Set the last updated timestamp on the state instance.
|
||||||
|
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
|
value = getattr(instance, self._cache_attr)
|
||||||
|
self._check_deprecated_return_type(instance, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return _awaitable_result()
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
|
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
|
||||||
|
|
||||||
@ -2315,10 +2475,27 @@ def computed_var(
|
|||||||
raise VarDependencyError("Cannot track dependencies without caching.")
|
raise VarDependencyError("Cannot track dependencies without caching.")
|
||||||
|
|
||||||
if fget is not None:
|
if fget is not None:
|
||||||
return ComputedVar(fget, cache=cache)
|
if inspect.iscoroutinefunction(fget):
|
||||||
|
computed_var_cls = AsyncComputedVar
|
||||||
|
else:
|
||||||
|
computed_var_cls = ComputedVar
|
||||||
|
return computed_var_cls(
|
||||||
|
fget,
|
||||||
|
initial_value=initial_value,
|
||||||
|
cache=cache,
|
||||||
|
deps=deps,
|
||||||
|
auto_deps=auto_deps,
|
||||||
|
interval=interval,
|
||||||
|
backend=backend,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
|
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
|
||||||
return ComputedVar(
|
if inspect.iscoroutinefunction(fget):
|
||||||
|
computed_var_cls = AsyncComputedVar
|
||||||
|
else:
|
||||||
|
computed_var_cls = ComputedVar
|
||||||
|
return computed_var_cls(
|
||||||
fget,
|
fget,
|
||||||
initial_value=initial_value,
|
initial_value=initial_value,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
|
@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
|||||||
assert app.pages.keys() == {"test/[dynamic]"}
|
assert app.pages.keys() == {"test/[dynamic]"}
|
||||||
assert "dynamic" in app.state.computed_vars
|
assert "dynamic" in app.state.computed_vars
|
||||||
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
||||||
constants.ROUTER
|
EmptyState.get_full_name(): {constants.ROUTER},
|
||||||
}
|
}
|
||||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
assert constants.ROUTER in app.state()._var_dependencies
|
||||||
|
|
||||||
|
|
||||||
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
|
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
|
||||||
@ -997,9 +997,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
assert arg_name in app.state.vars
|
assert arg_name in app.state.vars
|
||||||
assert arg_name in app.state.computed_vars
|
assert arg_name in app.state.computed_vars
|
||||||
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
|
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
|
||||||
constants.ROUTER
|
DynamicState.get_full_name(): {constants.ROUTER},
|
||||||
}
|
}
|
||||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
assert constants.ROUTER in app.state()._var_dependencies
|
||||||
|
|
||||||
substate_token = _substate_key(token, DynamicState)
|
substate_token = _substate_key(token, DynamicState)
|
||||||
sid = "mock_sid"
|
sid = "mock_sid"
|
||||||
@ -1557,6 +1557,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
|
|||||||
def bar(self) -> str:
|
def bar(self) -> str:
|
||||||
return "bar"
|
return "bar"
|
||||||
|
|
||||||
|
class Child1(ValidDepState):
|
||||||
|
@computed_var(deps=["base", ValidDepState.bar])
|
||||||
|
def other(self) -> str:
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
class Child2(ValidDepState):
|
||||||
|
@computed_var(deps=["base", Child1.other])
|
||||||
|
def other(self) -> str:
|
||||||
|
return "other"
|
||||||
|
|
||||||
app.state = ValidDepState
|
app.state = ValidDepState
|
||||||
app._compile()
|
app._compile()
|
||||||
|
|
||||||
|
@ -1170,13 +1170,11 @@ def test_conditional_computed_vars():
|
|||||||
|
|
||||||
ms = MainState()
|
ms = MainState()
|
||||||
# Initially there are no dirty computed vars.
|
# Initially there are no dirty computed vars.
|
||||||
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")}
|
||||||
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")}
|
||||||
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")}
|
||||||
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
|
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
|
||||||
"flag",
|
MainState.get_full_name(): {"flag", "t1", "t2"}
|
||||||
"t1",
|
|
||||||
"t2",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1371,7 +1369,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
|||||||
assert isinstance(HandlerState.handler, EventHandler)
|
assert isinstance(HandlerState.handler, EventHandler)
|
||||||
|
|
||||||
s = HandlerState()
|
s = HandlerState()
|
||||||
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
|
assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"]
|
||||||
assert s.cached_x_side_effect == 1
|
assert s.cached_x_side_effect == 1
|
||||||
assert s.x == 43
|
assert s.x == 43
|
||||||
s.handler()
|
s.handler()
|
||||||
@ -1461,15 +1459,15 @@ def test_computed_var_dependencies():
|
|||||||
return [z in self._z for z in range(5)]
|
return [z in self._z for z in range(5)]
|
||||||
|
|
||||||
cs = ComputedState()
|
cs = ComputedState()
|
||||||
assert cs._computed_var_dependencies["v"] == {
|
assert cs._var_dependencies["v"] == {
|
||||||
"comp_v",
|
(ComputedState.get_full_name(), "comp_v"),
|
||||||
"comp_v_backend",
|
(ComputedState.get_full_name(), "comp_v_backend"),
|
||||||
"comp_v_via_property",
|
(ComputedState.get_full_name(), "comp_v_via_property"),
|
||||||
}
|
}
|
||||||
assert cs._computed_var_dependencies["w"] == {"comp_w"}
|
assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
|
||||||
assert cs._computed_var_dependencies["x"] == {"comp_x"}
|
assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
|
||||||
assert cs._computed_var_dependencies["y"] == {"comp_y"}
|
assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
|
||||||
assert cs._computed_var_dependencies["_z"] == {"comp_z"}
|
assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
|
||||||
|
|
||||||
|
|
||||||
def test_backend_method():
|
def test_backend_method():
|
||||||
@ -3182,6 +3180,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
|
|||||||
RxState = State
|
RxState = State
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="This test is maybe not relevant anymore.")
|
||||||
def test_potentially_dirty_substates():
|
def test_potentially_dirty_substates():
|
||||||
"""Test that potentially_dirty_substates returns the correct substates.
|
"""Test that potentially_dirty_substates returns the correct substates.
|
||||||
|
|
||||||
@ -3203,7 +3202,8 @@ def test_potentially_dirty_substates():
|
|||||||
assert C1._potentially_dirty_substates() == set()
|
assert C1._potentially_dirty_substates() == set()
|
||||||
|
|
||||||
|
|
||||||
def test_router_var_dep() -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_router_var_dep() -> None:
|
||||||
"""Test that router var dependencies are correctly tracked."""
|
"""Test that router var dependencies are correctly tracked."""
|
||||||
|
|
||||||
class RouterVarParentState(State):
|
class RouterVarParentState(State):
|
||||||
@ -3221,13 +3221,9 @@ def test_router_var_dep() -> None:
|
|||||||
foo = RouterVarDepState.computed_vars["foo"]
|
foo = RouterVarDepState.computed_vars["foo"]
|
||||||
State._init_var_dependency_dicts()
|
State._init_var_dependency_dicts()
|
||||||
|
|
||||||
assert foo._deps(objclass=RouterVarDepState) == {"router"}
|
assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}}
|
||||||
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
|
assert State._var_dependencies == {
|
||||||
assert RouterVarParentState._substate_var_dependencies == {
|
"router": {(RouterVarDepState.get_full_name(), "foo")}
|
||||||
"router": {RouterVarDepState.get_name()}
|
|
||||||
}
|
|
||||||
assert RouterVarDepState._computed_var_dependencies == {
|
|
||||||
"router": {"foo"},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rx_state = State()
|
rx_state = State()
|
||||||
@ -3240,11 +3236,15 @@ def test_router_var_dep() -> None:
|
|||||||
state.parent_state = parent_state
|
state.parent_state = parent_state
|
||||||
parent_state.substates = {RouterVarDepState.get_name(): state}
|
parent_state.substates = {RouterVarDepState.get_name(): state}
|
||||||
|
|
||||||
|
populated_substate_classes = await rx_state._recursively_populate_dependent_substates()
|
||||||
|
assert populated_substate_classes == {State, RouterVarDepState}
|
||||||
|
|
||||||
assert state.dirty_vars == set()
|
assert state.dirty_vars == set()
|
||||||
|
|
||||||
# Reassign router var
|
# Reassign router var
|
||||||
state.router = state.router
|
state.router = state.router
|
||||||
assert state.dirty_vars == {"foo", "router"}
|
assert rx_state.dirty_vars == {"router"}
|
||||||
|
assert state.dirty_vars == {"foo"}
|
||||||
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
|
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
|
||||||
|
|
||||||
|
|
||||||
@ -3803,3 +3803,74 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
|
|||||||
# Generic Var with no state
|
# Generic Var with no state
|
||||||
with pytest.raises(UnretrievableVarValueError):
|
with pytest.raises(UnretrievableVarValueError):
|
||||||
await state.get_var_value(rx.Var("undefined"))
|
await state.get_var_value(rx.Var("undefined"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
|
||||||
|
"""A test where an async computed var depends on a var in another state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Parent(BaseState):
|
||||||
|
"""A root state like rx.State."""
|
||||||
|
|
||||||
|
parent_var: int = 0
|
||||||
|
|
||||||
|
class Child2(Parent):
|
||||||
|
"""An unconnected child state."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Child3(Parent):
|
||||||
|
"""A child state with a computed var causing it to be pre-fetched.
|
||||||
|
|
||||||
|
If child3_var gets set to a value, and `get_state` erroneously
|
||||||
|
re-fetches it from redis, the value will be lost.
|
||||||
|
"""
|
||||||
|
|
||||||
|
child3_var: int = 0
|
||||||
|
|
||||||
|
@rx.var(cache=True)
|
||||||
|
def v(self):
|
||||||
|
return self.child3_var
|
||||||
|
|
||||||
|
class Child(Parent):
|
||||||
|
"""A state simulating UpdateVarsInternalState."""
|
||||||
|
|
||||||
|
@rx.var(cache=True)
|
||||||
|
async def v(self):
|
||||||
|
p = await self.get_state(Parent)
|
||||||
|
child3 = await self.get_state(Child3)
|
||||||
|
return child3.child3_var + p.parent_var
|
||||||
|
|
||||||
|
mock_app.state_manager.state = mock_app.state = Parent
|
||||||
|
|
||||||
|
# Get the top level state via unconnected sibling.
|
||||||
|
root = await mock_app.state_manager.get_state(_substate_key(token, Child))
|
||||||
|
# Set value in parent_var to assert it does not get refetched later.
|
||||||
|
root.parent_var = 1
|
||||||
|
|
||||||
|
if isinstance(mock_app.state_manager, StateManagerRedis):
|
||||||
|
# When redis is used, only states with uncached computed vars are pre-fetched.
|
||||||
|
assert Child2.get_name() not in root.substates
|
||||||
|
assert Child3.get_name() not in root.substates
|
||||||
|
|
||||||
|
# Get the unconnected sibling state, which will be used to `get_state` other instances.
|
||||||
|
child = root.get_substate(Child.get_full_name().split("."))
|
||||||
|
|
||||||
|
# Get an uncached child state.
|
||||||
|
child2 = await child.get_state(Child2)
|
||||||
|
assert child2.parent_var == 1
|
||||||
|
|
||||||
|
# Set value on already-cached Child3 state (prefetched because it has a Computed Var).
|
||||||
|
child3 = await child.get_state(Child3)
|
||||||
|
child3.child3_var = 1
|
||||||
|
|
||||||
|
assert await child.v == 2
|
||||||
|
assert await child.v == 2
|
||||||
|
root.parent_var = 2
|
||||||
|
assert await child.v == 3
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ from reflex.utils.exceptions import PrimitiveUnserializableToJSON
|
|||||||
from reflex.utils.imports import ImportVar
|
from reflex.utils.imports import ImportVar
|
||||||
from reflex.vars import VarData
|
from reflex.vars import VarData
|
||||||
from reflex.vars.base import (
|
from reflex.vars.base import (
|
||||||
|
AsyncComputedVar,
|
||||||
ComputedVar,
|
ComputedVar,
|
||||||
LiteralVar,
|
LiteralVar,
|
||||||
Var,
|
Var,
|
||||||
@ -1808,9 +1809,9 @@ def cv_fget(state: BaseState) -> int:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"deps,expected",
|
"deps,expected",
|
||||||
[
|
[
|
||||||
(["a"], {"a"}),
|
(["a"], {None: {"a"}}),
|
||||||
(["b"], {"b"}),
|
(["b"], {None: {"b"}}),
|
||||||
([ComputedVar(fget=cv_fget)], {"cv_fget"}),
|
([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
|
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
|
||||||
@ -1856,3 +1857,25 @@ def test_to_string_operation():
|
|||||||
|
|
||||||
single_var = Var.create(Email())
|
single_var = Var.create(Email())
|
||||||
assert single_var._var_type == Email
|
assert single_var._var_type == Email
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_computed_var():
|
||||||
|
side_effect_counter = 0
|
||||||
|
|
||||||
|
class AsyncComputedVarState(BaseState):
|
||||||
|
v: int = 1
|
||||||
|
|
||||||
|
@computed_var(cache=True)
|
||||||
|
async def async_computed_var(self) -> int:
|
||||||
|
nonlocal side_effect_counter
|
||||||
|
side_effect_counter += 1
|
||||||
|
return self.v + 1
|
||||||
|
|
||||||
|
my_state = AsyncComputedVarState()
|
||||||
|
assert await my_state.async_computed_var == 2
|
||||||
|
assert await my_state.async_computed_var == 2
|
||||||
|
my_state.v = 2
|
||||||
|
assert await my_state.async_computed_var == 3
|
||||||
|
assert await my_state.async_computed_var == 3
|
||||||
|
assert side_effect_counter == 2
|
||||||
|
Loading…
Reference in New Issue
Block a user