add computed backend vars (#3573)
* add computed backend vars * finish computed backend vars, add tests * fix token for AppHarness with redis state manager * fix timing issues * add unit tests for computed backend vars * automagically mark cvs with _ prefix as backend var * fully migrate backend computed vars * rename is_backend_variable to is_backend_base_variable * add integration test for implicit backend cv, adjust comments * replace expensive backend var check at runtime * keep stuff together * simplify backend var check method, consistent naming, improve test typing * fix: do not convert properties to cvs * add test for property * fix cached_properties with _ prefix in state cls
This commit is contained in:
parent
bcc7a61452
commit
b7651e214b
@ -26,6 +26,16 @@ def ComputedVars():
|
||||
def count1(self) -> int:
|
||||
return self.count
|
||||
|
||||
# cached backend var with dep on count
|
||||
@rx.var(cache=True, interval=15, backend=True)
|
||||
def count1_backend(self) -> int:
|
||||
return self.count
|
||||
|
||||
# same as above but implicit backend with `_` prefix
|
||||
@rx.var(cache=True, interval=15)
|
||||
def _count1_backend(self) -> int:
|
||||
return self.count
|
||||
|
||||
# explicit disabled auto_deps
|
||||
@rx.var(interval=15, cache=True, auto_deps=False)
|
||||
def count3(self) -> int:
|
||||
@ -70,6 +80,10 @@ def ComputedVars():
|
||||
rx.text(State.count, id="count"),
|
||||
rx.text("count1:"),
|
||||
rx.text(State.count1, id="count1"),
|
||||
rx.text("count1_backend:"),
|
||||
rx.text(State.count1_backend, id="count1_backend"),
|
||||
rx.text("_count1_backend:"),
|
||||
rx.text(State._count1_backend, id="_count1_backend"),
|
||||
rx.text("count3:"),
|
||||
rx.text(State.count3, id="count3"),
|
||||
rx.text("depends_on_count:"),
|
||||
@ -154,7 +168,8 @@ def token(computed_vars: AppHarness, driver: WebDriver) -> str:
|
||||
return token
|
||||
|
||||
|
||||
def test_computed_vars(
|
||||
@pytest.mark.asyncio
|
||||
async def test_computed_vars(
|
||||
computed_vars: AppHarness,
|
||||
driver: WebDriver,
|
||||
token: str,
|
||||
@ -168,6 +183,20 @@ def test_computed_vars(
|
||||
"""
|
||||
assert computed_vars.app_instance is not None
|
||||
|
||||
token = f"{token}_state.state"
|
||||
state = (await computed_vars.get_state(token)).substates["state"]
|
||||
assert state is not None
|
||||
assert state.count1_backend == 0
|
||||
assert state._count1_backend == 0
|
||||
|
||||
# test that backend var is not rendered
|
||||
count1_backend = driver.find_element(By.ID, "count1_backend")
|
||||
assert count1_backend
|
||||
assert count1_backend.text == ""
|
||||
_count1_backend = driver.find_element(By.ID, "_count1_backend")
|
||||
assert _count1_backend
|
||||
assert _count1_backend.text == ""
|
||||
|
||||
count = driver.find_element(By.ID, "count")
|
||||
assert count
|
||||
assert count.text == "0"
|
||||
@ -207,6 +236,12 @@ def test_computed_vars(
|
||||
computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
|
||||
== "1"
|
||||
)
|
||||
state = (await computed_vars.get_state(token)).substates["state"]
|
||||
assert state is not None
|
||||
assert state.count1_backend == 1
|
||||
assert count1_backend.text == ""
|
||||
assert state._count1_backend == 1
|
||||
assert _count1_backend.text == ""
|
||||
|
||||
mark_dirty.click()
|
||||
with pytest.raises(TimeoutError):
|
||||
|
@ -305,10 +305,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Vars inherited by the parent state.
|
||||
inherited_vars: ClassVar[Dict[str, Var]] = {}
|
||||
|
||||
# Backend vars that are never sent to the client.
|
||||
# Backend base vars that are never sent to the client.
|
||||
backend_vars: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
# Backend vars inherited
|
||||
# Backend base vars inherited
|
||||
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
# The event handlers.
|
||||
@ -344,7 +344,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The routing path that triggered the state
|
||||
router_data: Dict[str, Any] = {}
|
||||
|
||||
# Per-instance copy of backend variable values
|
||||
# Per-instance copy of backend base variable values
|
||||
_backend_vars: Dict[str, Any] = {}
|
||||
|
||||
# The router data for the current page
|
||||
@ -492,21 +492,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
new_backend_vars = {
|
||||
name: value
|
||||
for name, value in cls.__dict__.items()
|
||||
if types.is_backend_variable(name, cls)
|
||||
}
|
||||
|
||||
# Get backend computed vars
|
||||
backend_computed_vars = {
|
||||
v._var_name: v._var_set_state(cls)
|
||||
for v in computed_vars
|
||||
if types.is_backend_variable(v._var_name, cls)
|
||||
and v._var_name not in cls.inherited_backend_vars
|
||||
if types.is_backend_base_variable(name, cls)
|
||||
}
|
||||
|
||||
cls.backend_vars = {
|
||||
**cls.inherited_backend_vars,
|
||||
**new_backend_vars,
|
||||
**backend_computed_vars,
|
||||
}
|
||||
|
||||
# Set the base and computed vars.
|
||||
@ -548,7 +539,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.computed_vars[newcv._var_name] = newcv
|
||||
cls.vars[newcv._var_name] = newcv
|
||||
continue
|
||||
if types.is_backend_variable(name, mixin):
|
||||
if types.is_backend_base_variable(name, mixin):
|
||||
cls.backend_vars[name] = copy.deepcopy(value)
|
||||
continue
|
||||
if events.get(name) is not None:
|
||||
@ -1087,7 +1078,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
setattr(self.parent_state, name, value)
|
||||
return
|
||||
|
||||
if types.is_backend_variable(name, type(self)):
|
||||
if name in self.backend_vars:
|
||||
self._backend_vars.__setitem__(name, value)
|
||||
self.dirty_vars.add(name)
|
||||
self._mark_dirty()
|
||||
@ -1538,11 +1529,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if self.computed_vars[cvar].needs_update(instance=self)
|
||||
)
|
||||
|
||||
def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
|
||||
def _dirty_computed_vars(
|
||||
self, from_vars: set[str] | None = None, include_backend: bool = True
|
||||
) -> set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||
|
||||
Args:
|
||||
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
||||
include_backend: whether to include backend vars in the calculation.
|
||||
|
||||
Returns:
|
||||
Set of computed vars to include in the delta.
|
||||
@ -1551,6 +1545,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cvar
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self._computed_var_dependencies[dirty_var]
|
||||
if include_backend or not self.computed_vars[cvar]._backend
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1586,19 +1581,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||
self._mark_dirty()
|
||||
|
||||
frontend_computed_vars: set[str] = {
|
||||
name for name, cv in self.computed_vars.items() if not cv._backend
|
||||
}
|
||||
|
||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||
# and always dirty computed vars (cache=False)
|
||||
delta_vars = (
|
||||
self.dirty_vars.intersection(self.base_vars)
|
||||
.union(self.dirty_vars.intersection(self.computed_vars))
|
||||
.union(self._dirty_computed_vars())
|
||||
.union(self.dirty_vars.intersection(frontend_computed_vars))
|
||||
.union(self._dirty_computed_vars(include_backend=False))
|
||||
.union(self._always_dirty_computed_vars)
|
||||
)
|
||||
|
||||
subdelta = {
|
||||
prop: getattr(self, prop)
|
||||
for prop in delta_vars
|
||||
if not types.is_backend_variable(prop, type(self))
|
||||
if not types.is_backend_base_variable(prop, type(self))
|
||||
}
|
||||
if len(subdelta) > 0:
|
||||
delta[self.get_full_name()] = subdelta
|
||||
@ -1727,12 +1726,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
else self.get_value(getattr(self, prop_name))
|
||||
)
|
||||
for prop_name, cv in self.computed_vars.items()
|
||||
if not cv._backend
|
||||
}
|
||||
elif include_computed:
|
||||
computed_vars = {
|
||||
# Include the computed vars.
|
||||
prop_name: self.get_value(getattr(self, prop_name))
|
||||
for prop_name in self.computed_vars
|
||||
for prop_name, cv in self.computed_vars.items()
|
||||
if not cv._backend
|
||||
}
|
||||
else:
|
||||
computed_vars = {}
|
||||
@ -1745,6 +1746,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for v in self.substates.values()
|
||||
]:
|
||||
d.update(substate_d)
|
||||
|
||||
return d
|
||||
|
||||
async def __aenter__(self) -> BaseState:
|
||||
|
@ -6,7 +6,7 @@ import contextlib
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from functools import wraps
|
||||
from functools import cached_property, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -410,7 +410,7 @@ def is_valid_var_type(type_: Type) -> bool:
|
||||
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
|
||||
|
||||
|
||||
def is_backend_variable(name: str, cls: Type | None = None) -> bool:
|
||||
def is_backend_base_variable(name: str, cls: Type) -> bool:
|
||||
"""Check if this variable name correspond to a backend variable.
|
||||
|
||||
Args:
|
||||
@ -429,31 +429,30 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
|
||||
if name.startswith("__"):
|
||||
return False
|
||||
|
||||
if cls is not None:
|
||||
if name.startswith(f"_{cls.__name__}__"):
|
||||
return False
|
||||
hints = get_type_hints(cls)
|
||||
if name in hints:
|
||||
hint = get_origin(hints[name])
|
||||
if hint == ClassVar:
|
||||
return False
|
||||
if name.startswith(f"_{cls.__name__}__"):
|
||||
return False
|
||||
|
||||
if name in cls.inherited_backend_vars:
|
||||
hints = get_type_hints(cls)
|
||||
if name in hints:
|
||||
hint = get_origin(hints[name])
|
||||
if hint == ClassVar:
|
||||
return False
|
||||
|
||||
if name in cls.__dict__:
|
||||
value = cls.__dict__[name]
|
||||
if type(value) == classmethod:
|
||||
return False
|
||||
if callable(value):
|
||||
return False
|
||||
if isinstance(value, types.FunctionType):
|
||||
return False
|
||||
# enable after #3573 is merged
|
||||
# from reflex.vars import ComputedVar
|
||||
#
|
||||
# if isinstance(value, ComputedVar):
|
||||
# return False
|
||||
if name in cls.inherited_backend_vars:
|
||||
return False
|
||||
|
||||
if name in cls.__dict__:
|
||||
value = cls.__dict__[name]
|
||||
if type(value) == classmethod:
|
||||
return False
|
||||
if callable(value):
|
||||
return False
|
||||
from reflex.vars import ComputedVar
|
||||
|
||||
if isinstance(
|
||||
value, (types.FunctionType, property, cached_property, ComputedVar)
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
@ -1944,6 +1944,9 @@ class ComputedVar(Var, property):
|
||||
# Whether to track dependencies and cache computed values
|
||||
_cache: bool = dataclasses.field(default=False)
|
||||
|
||||
# Whether the computed var is a backend var
|
||||
_backend: bool = dataclasses.field(default=False)
|
||||
|
||||
# The initial value of the computed var
|
||||
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
|
||||
|
||||
@ -1964,6 +1967,7 @@ class ComputedVar(Var, property):
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[int, datetime.timedelta]] = None,
|
||||
backend: bool | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize a ComputedVar.
|
||||
@ -1975,11 +1979,16 @@ class ComputedVar(Var, property):
|
||||
deps: Explicit var dependencies to track.
|
||||
auto_deps: Whether var dependencies should be auto-determined.
|
||||
interval: Interval at which the computed var should be updated.
|
||||
backend: Whether the computed var is a backend var.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
|
||||
Raises:
|
||||
TypeError: If the computed var dependencies are not Var instances or var names.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = fget.__name__.startswith("_")
|
||||
self._backend = backend
|
||||
|
||||
self._initial_value = initial_value
|
||||
self._cache = cache
|
||||
if isinstance(interval, int):
|
||||
@ -2023,6 +2032,7 @@ class ComputedVar(Var, property):
|
||||
deps=kwargs.get("deps", self._static_deps),
|
||||
auto_deps=kwargs.get("auto_deps", self._auto_deps),
|
||||
interval=kwargs.get("interval", self._update_interval),
|
||||
backend=kwargs.get("backend", self._backend),
|
||||
_var_name=kwargs.get("_var_name", self._var_name),
|
||||
_var_type=kwargs.get("_var_type", self._var_type),
|
||||
_var_is_local=kwargs.get("_var_is_local", self._var_is_local),
|
||||
@ -2233,6 +2243,7 @@ def computed_var(
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[datetime.timedelta, int]] = None,
|
||||
backend: bool | None = None,
|
||||
_deprecated_cached_var: bool = False,
|
||||
**kwargs,
|
||||
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
|
||||
@ -2245,6 +2256,7 @@ def computed_var(
|
||||
deps: Explicit var dependencies to track.
|
||||
auto_deps: Whether var dependencies should be auto-determined.
|
||||
interval: Interval at which the computed var should be updated.
|
||||
backend: Whether the computed var is a backend var.
|
||||
_deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
|
||||
@ -2280,6 +2292,7 @@ def computed_var(
|
||||
deps=deps,
|
||||
auto_deps=auto_deps,
|
||||
interval=interval,
|
||||
backend=backend,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -925,6 +925,15 @@ class InterdependentState(BaseState):
|
||||
"""
|
||||
return self._v2 * 2
|
||||
|
||||
@rx.var(cache=True, backend=True)
|
||||
def v2x2_backend(self) -> int:
|
||||
"""Depends on backend var _v2.
|
||||
|
||||
Returns:
|
||||
backend var _v2 multiplied by 2
|
||||
"""
|
||||
return self._v2 * 2
|
||||
|
||||
@rx.var(cache=True)
|
||||
def v1x2x2(self) -> int:
|
||||
"""Depends on ComputedVar v1x2.
|
||||
@ -1002,11 +1011,11 @@ def test_dirty_computed_var_from_backend_var(
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
assert InterdependentState._v3._backend is True
|
||||
interdependent_state._v2 = 2
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
|
||||
}
|
||||
assert "_v3" in InterdependentState.backend_vars
|
||||
|
||||
|
||||
def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
|
||||
@ -1295,6 +1304,15 @@ def test_computed_var_dependencies():
|
||||
"""
|
||||
return self.v
|
||||
|
||||
@rx.var(cache=True, backend=True)
|
||||
def comp_v_backend(self) -> int:
|
||||
"""Direct access backend var.
|
||||
|
||||
Returns:
|
||||
The value of self.v.
|
||||
"""
|
||||
return self.v
|
||||
|
||||
@rx.var(cache=True)
|
||||
def comp_v_via_property(self) -> int:
|
||||
"""Access v via property.
|
||||
@ -1345,7 +1363,11 @@ def test_computed_var_dependencies():
|
||||
return [z in self._z for z in range(5)]
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs._computed_var_dependencies["v"] == {"comp_v", "comp_v_via_property"}
|
||||
assert cs._computed_var_dependencies["v"] == {
|
||||
"comp_v",
|
||||
"comp_v_backend",
|
||||
"comp_v_via_property",
|
||||
}
|
||||
assert cs._computed_var_dependencies["w"] == {"comp_w"}
|
||||
assert cs._computed_var_dependencies["x"] == {"comp_x"}
|
||||
assert cs._computed_var_dependencies["y"] == {"comp_y"}
|
||||
|
@ -1,7 +1,8 @@
|
||||
import os
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, List, Literal, Union
|
||||
from typing import Any, ClassVar, List, Literal, Type, Union
|
||||
|
||||
import pytest
|
||||
import typer
|
||||
@ -161,6 +162,14 @@ def test_backend_variable_cls():
|
||||
def _hidden_method(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def _hidden_property(self):
|
||||
pass
|
||||
|
||||
@cached_property
|
||||
def _cached_hidden_property(self):
|
||||
pass
|
||||
|
||||
return TestBackendVariable
|
||||
|
||||
|
||||
@ -173,10 +182,14 @@ def test_backend_variable_cls():
|
||||
("_hidden", True),
|
||||
("not_hidden", False),
|
||||
("__dundermethod__", False),
|
||||
("_hidden_property", False),
|
||||
("_cached_hidden_property", False),
|
||||
],
|
||||
)
|
||||
def test_is_backend_variable(test_backend_variable_cls, input, output):
|
||||
assert types.is_backend_variable(input, test_backend_variable_cls) == output
|
||||
def test_is_backend_base_variable(
|
||||
test_backend_variable_cls: Type[BaseState], input: str, output: bool
|
||||
):
|
||||
assert types.is_backend_base_variable(input, test_backend_variable_cls) == output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
Loading…
Reference in New Issue
Block a user