add computed backend vars (#3573)

* add computed backend vars

* finish computed backend vars, add tests

* fix token for AppHarness with redis state manager

* fix timing issues

* add unit tests for computed backend vars

* automagically mark cvs with _ prefix as backend var

* fully migrate backend computed vars

* rename is_backend_variable to is_backend_base_variable

* add integration test for implicit backend cv, adjust comments

* replace expensive backend var check at runtime

* keep stuff together

* simplify backend var check method, consistent naming, improve test typing

* fix: do not convert properties to cvs

* add test for property

* fix cached_properties with _ prefix in state cls
This commit is contained in:
benedikt-bartscher 2024-06-29 02:01:07 +02:00 committed by GitHub
parent bcc7a61452
commit b7651e214b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 134 additions and 50 deletions

View File

@ -26,6 +26,16 @@ def ComputedVars():
def count1(self) -> int:
return self.count
# cached backend var with dep on count
@rx.var(cache=True, interval=15, backend=True)
def count1_backend(self) -> int:
return self.count
# same as above but implicit backend with `_` prefix
@rx.var(cache=True, interval=15)
def _count1_backend(self) -> int:
return self.count
# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
@ -70,6 +80,10 @@ def ComputedVars():
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count1_backend:"),
rx.text(State.count1_backend, id="count1_backend"),
rx.text("_count1_backend:"),
rx.text(State._count1_backend, id="_count1_backend"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count:"),
@ -154,7 +168,8 @@ def token(computed_vars: AppHarness, driver: WebDriver) -> str:
return token
def test_computed_vars(
@pytest.mark.asyncio
async def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
@ -168,6 +183,20 @@ def test_computed_vars(
"""
assert computed_vars.app_instance is not None
token = f"{token}_state.state"
state = (await computed_vars.get_state(token)).substates["state"]
assert state is not None
assert state.count1_backend == 0
assert state._count1_backend == 0
# test that backend var is not rendered
count1_backend = driver.find_element(By.ID, "count1_backend")
assert count1_backend
assert count1_backend.text == ""
_count1_backend = driver.find_element(By.ID, "_count1_backend")
assert _count1_backend
assert _count1_backend.text == ""
count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"
@ -207,6 +236,12 @@ def test_computed_vars(
computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
== "1"
)
state = (await computed_vars.get_state(token)).substates["state"]
assert state is not None
assert state.count1_backend == 1
assert count1_backend.text == ""
assert state._count1_backend == 1
assert _count1_backend.text == ""
mark_dirty.click()
with pytest.raises(TimeoutError):

View File

@ -305,10 +305,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Vars inherited by the parent state.
inherited_vars: ClassVar[Dict[str, Var]] = {}
# Backend vars that are never sent to the client.
# Backend base vars that are never sent to the client.
backend_vars: ClassVar[Dict[str, Any]] = {}
# Backend vars inherited
# Backend base vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
# The event handlers.
@ -344,7 +344,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state
router_data: Dict[str, Any] = {}
# Per-instance copy of backend variable values
# Per-instance copy of backend base variable values
_backend_vars: Dict[str, Any] = {}
# The router data for the current page
@ -492,21 +492,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
new_backend_vars = {
name: value
for name, value in cls.__dict__.items()
if types.is_backend_variable(name, cls)
}
# Get backend computed vars
backend_computed_vars = {
v._var_name: v._var_set_state(cls)
for v in computed_vars
if types.is_backend_variable(v._var_name, cls)
and v._var_name not in cls.inherited_backend_vars
if types.is_backend_base_variable(name, cls)
}
cls.backend_vars = {
**cls.inherited_backend_vars,
**new_backend_vars,
**backend_computed_vars,
}
# Set the base and computed vars.
@ -548,7 +539,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.computed_vars[newcv._var_name] = newcv
cls.vars[newcv._var_name] = newcv
continue
if types.is_backend_variable(name, mixin):
if types.is_backend_base_variable(name, mixin):
cls.backend_vars[name] = copy.deepcopy(value)
continue
if events.get(name) is not None:
@ -1087,7 +1078,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
setattr(self.parent_state, name, value)
return
if types.is_backend_variable(name, type(self)):
if name in self.backend_vars:
self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self._mark_dirty()
@ -1538,11 +1529,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if self.computed_vars[cvar].needs_update(instance=self)
)
def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Args:
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
include_backend: whether to include backend vars in the calculation.
Returns:
Set of computed vars to include in the delta.
@ -1551,6 +1545,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self._computed_var_dependencies[dirty_var]
if include_backend or not self.computed_vars[cvar]._backend
)
@classmethod
@ -1586,19 +1581,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()
frontend_computed_vars: set[str] = {
name for name, cv in self.computed_vars.items() if not cv._backend
}
# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
.union(self._dirty_computed_vars())
.union(self.dirty_vars.intersection(frontend_computed_vars))
.union(self._dirty_computed_vars(include_backend=False))
.union(self._always_dirty_computed_vars)
)
subdelta = {
prop: getattr(self, prop)
for prop in delta_vars
if not types.is_backend_variable(prop, type(self))
if not types.is_backend_base_variable(prop, type(self))
}
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
@ -1727,12 +1726,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
else self.get_value(getattr(self, prop_name))
)
for prop_name, cv in self.computed_vars.items()
if not cv._backend
}
elif include_computed:
computed_vars = {
# Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.computed_vars
for prop_name, cv in self.computed_vars.items()
if not cv._backend
}
else:
computed_vars = {}
@ -1745,6 +1746,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for v in self.substates.values()
]:
d.update(substate_d)
return d
async def __aenter__(self) -> BaseState:

View File

@ -6,7 +6,7 @@ import contextlib
import inspect
import sys
import types
from functools import wraps
from functools import cached_property, wraps
from typing import (
Any,
Callable,
@ -410,7 +410,7 @@ def is_valid_var_type(type_: Type) -> bool:
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
def is_backend_variable(name: str, cls: Type | None = None) -> bool:
def is_backend_base_variable(name: str, cls: Type) -> bool:
"""Check if this variable name correspond to a backend variable.
Args:
@ -429,31 +429,30 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
if name.startswith("__"):
return False
if cls is not None:
if name.startswith(f"_{cls.__name__}__"):
return False
hints = get_type_hints(cls)
if name in hints:
hint = get_origin(hints[name])
if hint == ClassVar:
return False
if name.startswith(f"_{cls.__name__}__"):
return False
if name in cls.inherited_backend_vars:
hints = get_type_hints(cls)
if name in hints:
hint = get_origin(hints[name])
if hint == ClassVar:
return False
if name in cls.__dict__:
value = cls.__dict__[name]
if type(value) == classmethod:
return False
if callable(value):
return False
if isinstance(value, types.FunctionType):
return False
# enable after #3573 is merged
# from reflex.vars import ComputedVar
#
# if isinstance(value, ComputedVar):
# return False
if name in cls.inherited_backend_vars:
return False
if name in cls.__dict__:
value = cls.__dict__[name]
if type(value) == classmethod:
return False
if callable(value):
return False
from reflex.vars import ComputedVar
if isinstance(
value, (types.FunctionType, property, cached_property, ComputedVar)
):
return False
return True

View File

@ -1944,6 +1944,9 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False)
# Whether the computed var is a backend var
_backend: bool = dataclasses.field(default=False)
# The initial value of the computed var
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
@ -1964,6 +1967,7 @@ class ComputedVar(Var, property):
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
backend: bool | None = None,
**kwargs,
):
"""Initialize a ComputedVar.
@ -1975,11 +1979,16 @@ class ComputedVar(Var, property):
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
"""
if backend is None:
backend = fget.__name__.startswith("_")
self._backend = backend
self._initial_value = initial_value
self._cache = cache
if isinstance(interval, int):
@ -2023,6 +2032,7 @@ class ComputedVar(Var, property):
deps=kwargs.get("deps", self._static_deps),
auto_deps=kwargs.get("auto_deps", self._auto_deps),
interval=kwargs.get("interval", self._update_interval),
backend=kwargs.get("backend", self._backend),
_var_name=kwargs.get("_var_name", self._var_name),
_var_type=kwargs.get("_var_type", self._var_type),
_var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@ -2233,6 +2243,7 @@ def computed_var(
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None,
_deprecated_cached_var: bool = False,
**kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
@ -2245,6 +2256,7 @@ def computed_var(
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
_deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
**kwargs: additional attributes to set on the instance
@ -2280,6 +2292,7 @@ def computed_var(
deps=deps,
auto_deps=auto_deps,
interval=interval,
backend=backend,
**kwargs,
)

View File

@ -925,6 +925,15 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
@rx.var(cache=True, backend=True)
def v2x2_backend(self) -> int:
"""Depends on backend var _v2.
Returns:
backend var _v2 multiplied by 2
"""
return self._v2 * 2
@rx.var(cache=True)
def v1x2x2(self) -> int:
"""Depends on ComputedVar v1x2.
@ -1002,11 +1011,11 @@ def test_dirty_computed_var_from_backend_var(
Args:
interdependent_state: A state with varying Var dependencies.
"""
assert InterdependentState._v3._backend is True
interdependent_state._v2 = 2
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
}
assert "_v3" in InterdependentState.backend_vars
def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
@ -1295,6 +1304,15 @@ def test_computed_var_dependencies():
"""
return self.v
@rx.var(cache=True, backend=True)
def comp_v_backend(self) -> int:
"""Direct access backend var.
Returns:
The value of self.v.
"""
return self.v
@rx.var(cache=True)
def comp_v_via_property(self) -> int:
"""Access v via property.
@ -1345,7 +1363,11 @@ def test_computed_var_dependencies():
return [z in self._z for z in range(5)]
cs = ComputedState()
assert cs._computed_var_dependencies["v"] == {"comp_v", "comp_v_via_property"}
assert cs._computed_var_dependencies["v"] == {
"comp_v",
"comp_v_backend",
"comp_v_via_property",
}
assert cs._computed_var_dependencies["w"] == {"comp_w"}
assert cs._computed_var_dependencies["x"] == {"comp_x"}
assert cs._computed_var_dependencies["y"] == {"comp_y"}

View File

@ -1,7 +1,8 @@
import os
import typing
from functools import cached_property
from pathlib import Path
from typing import Any, ClassVar, List, Literal, Union
from typing import Any, ClassVar, List, Literal, Type, Union
import pytest
import typer
@ -161,6 +162,14 @@ def test_backend_variable_cls():
def _hidden_method(self):
pass
@property
def _hidden_property(self):
pass
@cached_property
def _cached_hidden_property(self):
pass
return TestBackendVariable
@ -173,10 +182,14 @@ def test_backend_variable_cls():
("_hidden", True),
("not_hidden", False),
("__dundermethod__", False),
("_hidden_property", False),
("_cached_hidden_property", False),
],
)
def test_is_backend_variable(test_backend_variable_cls, input, output):
assert types.is_backend_variable(input, test_backend_variable_cls) == output
def test_is_backend_base_variable(
test_backend_variable_cls: Type[BaseState], input: str, output: bool
):
assert types.is_backend_base_variable(input, test_backend_variable_cls) == output
@pytest.mark.parametrize(