add computed backend vars (#3573)

* add computed backend vars

* finish computed backend vars, add tests

* fix token for AppHarness with redis state manager

* fix timing issues

* add unit tests for computed backend vars

* automagically mark cvs with _ prefix as backend var

* fully migrate backend computed vars

* rename is_backend_variable to is_backend_base_variable

* add integration test for implicit backend cv, adjust comments

* replace expensive backend var check at runtime

* keep stuff together

* simplify backend var check method, consistent naming, improve test typing

* fix: do not convert properties to cvs

* add test for property

* fix cached_properties with _ prefix in state cls
This commit is contained in:
benedikt-bartscher 2024-06-29 02:01:07 +02:00 committed by GitHub
parent bcc7a61452
commit b7651e214b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 134 additions and 50 deletions

View File

@ -26,6 +26,16 @@ def ComputedVars():
def count1(self) -> int: def count1(self) -> int:
return self.count return self.count
# cached backend var with dep on count
@rx.var(cache=True, interval=15, backend=True)
def count1_backend(self) -> int:
return self.count
# same as above but implicit backend with `_` prefix
@rx.var(cache=True, interval=15)
def _count1_backend(self) -> int:
return self.count
# explicit disabled auto_deps # explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False) @rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int: def count3(self) -> int:
@ -70,6 +80,10 @@ def ComputedVars():
rx.text(State.count, id="count"), rx.text(State.count, id="count"),
rx.text("count1:"), rx.text("count1:"),
rx.text(State.count1, id="count1"), rx.text(State.count1, id="count1"),
rx.text("count1_backend:"),
rx.text(State.count1_backend, id="count1_backend"),
rx.text("_count1_backend:"),
rx.text(State._count1_backend, id="_count1_backend"),
rx.text("count3:"), rx.text("count3:"),
rx.text(State.count3, id="count3"), rx.text(State.count3, id="count3"),
rx.text("depends_on_count:"), rx.text("depends_on_count:"),
@ -154,7 +168,8 @@ def token(computed_vars: AppHarness, driver: WebDriver) -> str:
return token return token
def test_computed_vars( @pytest.mark.asyncio
async def test_computed_vars(
computed_vars: AppHarness, computed_vars: AppHarness,
driver: WebDriver, driver: WebDriver,
token: str, token: str,
@ -168,6 +183,20 @@ def test_computed_vars(
""" """
assert computed_vars.app_instance is not None assert computed_vars.app_instance is not None
token = f"{token}_state.state"
state = (await computed_vars.get_state(token)).substates["state"]
assert state is not None
assert state.count1_backend == 0
assert state._count1_backend == 0
# test that backend var is not rendered
count1_backend = driver.find_element(By.ID, "count1_backend")
assert count1_backend
assert count1_backend.text == ""
_count1_backend = driver.find_element(By.ID, "_count1_backend")
assert _count1_backend
assert _count1_backend.text == ""
count = driver.find_element(By.ID, "count") count = driver.find_element(By.ID, "count")
assert count assert count
assert count.text == "0" assert count.text == "0"
@ -207,6 +236,12 @@ def test_computed_vars(
computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0") computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
== "1" == "1"
) )
state = (await computed_vars.get_state(token)).substates["state"]
assert state is not None
assert state.count1_backend == 1
assert count1_backend.text == ""
assert state._count1_backend == 1
assert _count1_backend.text == ""
mark_dirty.click() mark_dirty.click()
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):

View File

@ -305,10 +305,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Vars inherited by the parent state. # Vars inherited by the parent state.
inherited_vars: ClassVar[Dict[str, Var]] = {} inherited_vars: ClassVar[Dict[str, Var]] = {}
# Backend vars that are never sent to the client. # Backend base vars that are never sent to the client.
backend_vars: ClassVar[Dict[str, Any]] = {} backend_vars: ClassVar[Dict[str, Any]] = {}
# Backend vars inherited # Backend base vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {} inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
# The event handlers. # The event handlers.
@ -344,7 +344,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state # The routing path that triggered the state
router_data: Dict[str, Any] = {} router_data: Dict[str, Any] = {}
# Per-instance copy of backend variable values # Per-instance copy of backend base variable values
_backend_vars: Dict[str, Any] = {} _backend_vars: Dict[str, Any] = {}
# The router data for the current page # The router data for the current page
@ -492,21 +492,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
new_backend_vars = { new_backend_vars = {
name: value name: value
for name, value in cls.__dict__.items() for name, value in cls.__dict__.items()
if types.is_backend_variable(name, cls) if types.is_backend_base_variable(name, cls)
}
# Get backend computed vars
backend_computed_vars = {
v._var_name: v._var_set_state(cls)
for v in computed_vars
if types.is_backend_variable(v._var_name, cls)
and v._var_name not in cls.inherited_backend_vars
} }
cls.backend_vars = { cls.backend_vars = {
**cls.inherited_backend_vars, **cls.inherited_backend_vars,
**new_backend_vars, **new_backend_vars,
**backend_computed_vars,
} }
# Set the base and computed vars. # Set the base and computed vars.
@ -548,7 +539,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.computed_vars[newcv._var_name] = newcv cls.computed_vars[newcv._var_name] = newcv
cls.vars[newcv._var_name] = newcv cls.vars[newcv._var_name] = newcv
continue continue
if types.is_backend_variable(name, mixin): if types.is_backend_base_variable(name, mixin):
cls.backend_vars[name] = copy.deepcopy(value) cls.backend_vars[name] = copy.deepcopy(value)
continue continue
if events.get(name) is not None: if events.get(name) is not None:
@ -1087,7 +1078,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
setattr(self.parent_state, name, value) setattr(self.parent_state, name, value)
return return
if types.is_backend_variable(name, type(self)): if name in self.backend_vars:
self._backend_vars.__setitem__(name, value) self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name) self.dirty_vars.add(name)
self._mark_dirty() self._mark_dirty()
@ -1538,11 +1529,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if self.computed_vars[cvar].needs_update(instance=self) if self.computed_vars[cvar].needs_update(instance=self)
) )
def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]: def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars. """Determine ComputedVars that need to be recalculated based on the given vars.
Args: Args:
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars. from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
include_backend: whether to include backend vars in the calculation.
Returns: Returns:
Set of computed vars to include in the delta. Set of computed vars to include in the delta.
@ -1551,6 +1545,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cvar cvar
for dirty_var in from_vars or self.dirty_vars for dirty_var in from_vars or self.dirty_vars
for cvar in self._computed_var_dependencies[dirty_var] for cvar in self._computed_var_dependencies[dirty_var]
if include_backend or not self.computed_vars[cvar]._backend
) )
@classmethod @classmethod
@ -1586,19 +1581,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.dirty_vars.update(self._always_dirty_computed_vars) self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty() self._mark_dirty()
frontend_computed_vars: set[str] = {
name for name, cv in self.computed_vars.items() if not cv._backend
}
# Return the dirty vars for this instance, any cached/dependent computed vars, # Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False) # and always dirty computed vars (cache=False)
delta_vars = ( delta_vars = (
self.dirty_vars.intersection(self.base_vars) self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars)) .union(self.dirty_vars.intersection(frontend_computed_vars))
.union(self._dirty_computed_vars()) .union(self._dirty_computed_vars(include_backend=False))
.union(self._always_dirty_computed_vars) .union(self._always_dirty_computed_vars)
) )
subdelta = { subdelta = {
prop: getattr(self, prop) prop: getattr(self, prop)
for prop in delta_vars for prop in delta_vars
if not types.is_backend_variable(prop, type(self)) if not types.is_backend_base_variable(prop, type(self))
} }
if len(subdelta) > 0: if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta delta[self.get_full_name()] = subdelta
@ -1727,12 +1726,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
else self.get_value(getattr(self, prop_name)) else self.get_value(getattr(self, prop_name))
) )
for prop_name, cv in self.computed_vars.items() for prop_name, cv in self.computed_vars.items()
if not cv._backend
} }
elif include_computed: elif include_computed:
computed_vars = { computed_vars = {
# Include the computed vars. # Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name)) prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.computed_vars for prop_name, cv in self.computed_vars.items()
if not cv._backend
} }
else: else:
computed_vars = {} computed_vars = {}
@ -1745,6 +1746,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for v in self.substates.values() for v in self.substates.values()
]: ]:
d.update(substate_d) d.update(substate_d)
return d return d
async def __aenter__(self) -> BaseState: async def __aenter__(self) -> BaseState:

View File

@ -6,7 +6,7 @@ import contextlib
import inspect import inspect
import sys import sys
import types import types
from functools import wraps from functools import cached_property, wraps
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -410,7 +410,7 @@ def is_valid_var_type(type_: Type) -> bool:
return _issubclass(type_, StateVar) or serializers.has_serializer(type_) return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
def is_backend_variable(name: str, cls: Type | None = None) -> bool: def is_backend_base_variable(name: str, cls: Type) -> bool:
"""Check if this variable name correspond to a backend variable. """Check if this variable name correspond to a backend variable.
Args: Args:
@ -429,31 +429,30 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
if name.startswith("__"): if name.startswith("__"):
return False return False
if cls is not None: if name.startswith(f"_{cls.__name__}__"):
if name.startswith(f"_{cls.__name__}__"): return False
return False
hints = get_type_hints(cls)
if name in hints:
hint = get_origin(hints[name])
if hint == ClassVar:
return False
if name in cls.inherited_backend_vars: hints = get_type_hints(cls)
if name in hints:
hint = get_origin(hints[name])
if hint == ClassVar:
return False return False
if name in cls.__dict__: if name in cls.inherited_backend_vars:
value = cls.__dict__[name] return False
if type(value) == classmethod:
return False if name in cls.__dict__:
if callable(value): value = cls.__dict__[name]
return False if type(value) == classmethod:
if isinstance(value, types.FunctionType): return False
return False if callable(value):
# enable after #3573 is merged return False
# from reflex.vars import ComputedVar from reflex.vars import ComputedVar
#
# if isinstance(value, ComputedVar): if isinstance(
# return False value, (types.FunctionType, property, cached_property, ComputedVar)
):
return False
return True return True

View File

@ -1944,6 +1944,9 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values # Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False) _cache: bool = dataclasses.field(default=False)
# Whether the computed var is a backend var
_backend: bool = dataclasses.field(default=False)
# The initial value of the computed var # The initial value of the computed var
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset()) _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
@ -1964,6 +1967,7 @@ class ComputedVar(Var, property):
deps: Optional[List[Union[str, Var]]] = None, deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True, auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None, interval: Optional[Union[int, datetime.timedelta]] = None,
backend: bool | None = None,
**kwargs, **kwargs,
): ):
"""Initialize a ComputedVar. """Initialize a ComputedVar.
@ -1975,11 +1979,16 @@ class ComputedVar(Var, property):
deps: Explicit var dependencies to track. deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined. auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated. interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
Raises: Raises:
TypeError: If the computed var dependencies are not Var instances or var names. TypeError: If the computed var dependencies are not Var instances or var names.
""" """
if backend is None:
backend = fget.__name__.startswith("_")
self._backend = backend
self._initial_value = initial_value self._initial_value = initial_value
self._cache = cache self._cache = cache
if isinstance(interval, int): if isinstance(interval, int):
@ -2023,6 +2032,7 @@ class ComputedVar(Var, property):
deps=kwargs.get("deps", self._static_deps), deps=kwargs.get("deps", self._static_deps),
auto_deps=kwargs.get("auto_deps", self._auto_deps), auto_deps=kwargs.get("auto_deps", self._auto_deps),
interval=kwargs.get("interval", self._update_interval), interval=kwargs.get("interval", self._update_interval),
backend=kwargs.get("backend", self._backend),
_var_name=kwargs.get("_var_name", self._var_name), _var_name=kwargs.get("_var_name", self._var_name),
_var_type=kwargs.get("_var_type", self._var_type), _var_type=kwargs.get("_var_type", self._var_type),
_var_is_local=kwargs.get("_var_is_local", self._var_is_local), _var_is_local=kwargs.get("_var_is_local", self._var_is_local),
@ -2233,6 +2243,7 @@ def computed_var(
deps: Optional[List[Union[str, Var]]] = None, deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True, auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None, interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None,
_deprecated_cached_var: bool = False, _deprecated_cached_var: bool = False,
**kwargs, **kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]: ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
@ -2245,6 +2256,7 @@ def computed_var(
deps: Explicit var dependencies to track. deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined. auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated. interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
_deprecated_cached_var: Indicate usage of deprecated cached_var partial function. _deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
@ -2280,6 +2292,7 @@ def computed_var(
deps=deps, deps=deps,
auto_deps=auto_deps, auto_deps=auto_deps,
interval=interval, interval=interval,
backend=backend,
**kwargs, **kwargs,
) )

View File

@ -925,6 +925,15 @@ class InterdependentState(BaseState):
""" """
return self._v2 * 2 return self._v2 * 2
@rx.var(cache=True, backend=True)
def v2x2_backend(self) -> int:
"""Depends on backend var _v2.
Returns:
backend var _v2 multiplied by 2
"""
return self._v2 * 2
@rx.var(cache=True) @rx.var(cache=True)
def v1x2x2(self) -> int: def v1x2x2(self) -> int:
"""Depends on ComputedVar v1x2. """Depends on ComputedVar v1x2.
@ -1002,11 +1011,11 @@ def test_dirty_computed_var_from_backend_var(
Args: Args:
interdependent_state: A state with varying Var dependencies. interdependent_state: A state with varying Var dependencies.
""" """
assert InterdependentState._v3._backend is True
interdependent_state._v2 = 2 interdependent_state._v2 = 2
assert interdependent_state.get_delta() == { assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4}, interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
} }
assert "_v3" in InterdependentState.backend_vars
def test_per_state_backend_var(interdependent_state: InterdependentState) -> None: def test_per_state_backend_var(interdependent_state: InterdependentState) -> None:
@ -1295,6 +1304,15 @@ def test_computed_var_dependencies():
""" """
return self.v return self.v
@rx.var(cache=True, backend=True)
def comp_v_backend(self) -> int:
"""Direct access backend var.
Returns:
The value of self.v.
"""
return self.v
@rx.var(cache=True) @rx.var(cache=True)
def comp_v_via_property(self) -> int: def comp_v_via_property(self) -> int:
"""Access v via property. """Access v via property.
@ -1345,7 +1363,11 @@ def test_computed_var_dependencies():
return [z in self._z for z in range(5)] return [z in self._z for z in range(5)]
cs = ComputedState() cs = ComputedState()
assert cs._computed_var_dependencies["v"] == {"comp_v", "comp_v_via_property"} assert cs._computed_var_dependencies["v"] == {
"comp_v",
"comp_v_backend",
"comp_v_via_property",
}
assert cs._computed_var_dependencies["w"] == {"comp_w"} assert cs._computed_var_dependencies["w"] == {"comp_w"}
assert cs._computed_var_dependencies["x"] == {"comp_x"} assert cs._computed_var_dependencies["x"] == {"comp_x"}
assert cs._computed_var_dependencies["y"] == {"comp_y"} assert cs._computed_var_dependencies["y"] == {"comp_y"}

View File

@ -1,7 +1,8 @@
import os import os
import typing import typing
from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, List, Literal, Union from typing import Any, ClassVar, List, Literal, Type, Union
import pytest import pytest
import typer import typer
@ -161,6 +162,14 @@ def test_backend_variable_cls():
def _hidden_method(self): def _hidden_method(self):
pass pass
@property
def _hidden_property(self):
pass
@cached_property
def _cached_hidden_property(self):
pass
return TestBackendVariable return TestBackendVariable
@ -173,10 +182,14 @@ def test_backend_variable_cls():
("_hidden", True), ("_hidden", True),
("not_hidden", False), ("not_hidden", False),
("__dundermethod__", False), ("__dundermethod__", False),
("_hidden_property", False),
("_cached_hidden_property", False),
], ],
) )
def test_is_backend_variable(test_backend_variable_cls, input, output): def test_is_backend_base_variable(
assert types.is_backend_variable(input, test_backend_variable_cls) == output test_backend_variable_cls: Type[BaseState], input: str, output: bool
):
assert types.is_backend_base_variable(input, test_backend_variable_cls) == output
@pytest.mark.parametrize( @pytest.mark.parametrize(