From b7651e214bf09e607f589d8c5f0d96bdb4f7dc63 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Sat, 29 Jun 2024 02:01:07 +0200 Subject: [PATCH] add computed backend vars (#3573) * add computed backend vars * finish computed backend vars, add tests * fix token for AppHarness with redis state manager * fix timing issues * add unit tests for computed backend vars * automagically mark cvs with _ prefix as backend var * fully migrate backend computed vars * rename is_backend_variable to is_backend_base_variable * add integration test for implicit backend cv, adjust comments * replace expensive backend var check at runtime * keep stuff together * simplify backend var check method, consistent naming, improve test typing * fix: do not convert properties to cvs * add test for property * fix cached_properties with _ prefix in state cls --- integration/test_computed_vars.py | 37 +++++++++++++++++++++++- reflex/state.py | 42 ++++++++++++++------------- reflex/utils/types.py | 47 +++++++++++++++---------------- reflex/vars.py | 13 +++++++++ tests/test_state.py | 26 +++++++++++++++-- tests/utils/test_utils.py | 19 +++++++++++-- 6 files changed, 134 insertions(+), 50 deletions(-) diff --git a/integration/test_computed_vars.py b/integration/test_computed_vars.py index 7ccf01813..1b86f0ac7 100644 --- a/integration/test_computed_vars.py +++ b/integration/test_computed_vars.py @@ -26,6 +26,16 @@ def ComputedVars(): def count1(self) -> int: return self.count + # cached backend var with dep on count + @rx.var(cache=True, interval=15, backend=True) + def count1_backend(self) -> int: + return self.count + + # same as above but implicit backend with `_` prefix + @rx.var(cache=True, interval=15) + def _count1_backend(self) -> int: + return self.count + # explicit disabled auto_deps @rx.var(interval=15, cache=True, auto_deps=False) def count3(self) -> int: @@ -70,6 +80,10 @@ def ComputedVars(): rx.text(State.count, id="count"), rx.text("count1:"), rx.text(State.count1, id="count1"), + rx.text("count1_backend:"), + rx.text(State.count1_backend, id="count1_backend"), + rx.text("_count1_backend:"), + rx.text(State._count1_backend, id="_count1_backend"), rx.text("count3:"), rx.text(State.count3, id="count3"), rx.text("depends_on_count:"), @@ -154,7 +168,8 @@ def token(computed_vars: AppHarness, driver: WebDriver) -> str: return token -def test_computed_vars( +@pytest.mark.asyncio +async def test_computed_vars( computed_vars: AppHarness, driver: WebDriver, token: str, @@ -168,6 +183,20 @@ def test_computed_vars( """ assert computed_vars.app_instance is not None + token = f"{token}_state.state" + state = (await computed_vars.get_state(token)).substates["state"] + assert state is not None + assert state.count1_backend == 0 + assert state._count1_backend == 0 + + # test that backend var is not rendered + count1_backend = driver.find_element(By.ID, "count1_backend") + assert count1_backend + assert count1_backend.text == "" + _count1_backend = driver.find_element(By.ID, "_count1_backend") + assert _count1_backend + assert _count1_backend.text == "" + count = driver.find_element(By.ID, "count") assert count assert count.text == "0" @@ -207,6 +236,12 @@ def test_computed_vars( computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0") == "1" ) + state = (await computed_vars.get_state(token)).substates["state"] + assert state is not None + assert state.count1_backend == 1 + assert count1_backend.text == "" + assert state._count1_backend == 1 + assert _count1_backend.text == "" mark_dirty.click() with pytest.raises(TimeoutError): diff --git a/reflex/state.py b/reflex/state.py index fd5c8d3ca..fba7355c6 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -305,10 +305,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Vars inherited by the parent state. inherited_vars: ClassVar[Dict[str, Var]] = {} - # Backend vars that are never sent to the client. + # Backend base vars that are never sent to the client. backend_vars: ClassVar[Dict[str, Any]] = {} - # Backend vars inherited + # Backend base vars inherited inherited_backend_vars: ClassVar[Dict[str, Any]] = {} # The event handlers. @@ -344,7 +344,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The routing path that triggered the state router_data: Dict[str, Any] = {} - # Per-instance copy of backend variable values + # Per-instance copy of backend base variable values _backend_vars: Dict[str, Any] = {} # The router data for the current page @@ -492,21 +492,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): new_backend_vars = { name: value for name, value in cls.__dict__.items() - if types.is_backend_variable(name, cls) - } - - # Get backend computed vars - backend_computed_vars = { - v._var_name: v._var_set_state(cls) - for v in computed_vars - if types.is_backend_variable(v._var_name, cls) - and v._var_name not in cls.inherited_backend_vars + if types.is_backend_base_variable(name, cls) } cls.backend_vars = { **cls.inherited_backend_vars, **new_backend_vars, - **backend_computed_vars, } # Set the base and computed vars. @@ -548,7 +539,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls.computed_vars[newcv._var_name] = newcv cls.vars[newcv._var_name] = newcv continue - if types.is_backend_variable(name, mixin): + if types.is_backend_base_variable(name, mixin): cls.backend_vars[name] = copy.deepcopy(value) continue if events.get(name) is not None: @@ -1087,7 +1078,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): setattr(self.parent_state, name, value) return - if types.is_backend_variable(name, type(self)): + if name in self.backend_vars: self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) self._mark_dirty() @@ -1538,11 +1529,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if self.computed_vars[cvar].needs_update(instance=self) ) - def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]: + def _dirty_computed_vars( + self, from_vars: set[str] | None = None, include_backend: bool = True + ) -> set[str]: """Determine ComputedVars that need to be recalculated based on the given vars. Args: from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars. + include_backend: whether to include backend vars in the calculation. Returns: Set of computed vars to include in the delta. @@ -1551,6 +1545,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cvar for dirty_var in from_vars or self.dirty_vars for cvar in self._computed_var_dependencies[dirty_var] + if include_backend or not self.computed_vars[cvar]._backend ) @classmethod @@ -1586,19 +1581,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self.dirty_vars.update(self._always_dirty_computed_vars) self._mark_dirty() + frontend_computed_vars: set[str] = { + name for name, cv in self.computed_vars.items() if not cv._backend + } + # Return the dirty vars for this instance, any cached/dependent computed vars, # and always dirty computed vars (cache=False) delta_vars = ( self.dirty_vars.intersection(self.base_vars) - .union(self.dirty_vars.intersection(self.computed_vars)) - .union(self._dirty_computed_vars()) + .union(self.dirty_vars.intersection(frontend_computed_vars)) + .union(self._dirty_computed_vars(include_backend=False)) .union(self._always_dirty_computed_vars) ) subdelta = { prop: getattr(self, prop) for prop in delta_vars - if not types.is_backend_variable(prop, type(self)) + if not types.is_backend_base_variable(prop, type(self)) } if len(subdelta) > 0: delta[self.get_full_name()] = subdelta @@ -1727,12 +1726,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): else self.get_value(getattr(self, prop_name)) ) for prop_name, cv in self.computed_vars.items() + if not cv._backend } elif include_computed: computed_vars = { # Include the computed vars. prop_name: self.get_value(getattr(self, prop_name)) - for prop_name in self.computed_vars + for prop_name, cv in self.computed_vars.items() + if not cv._backend } else: computed_vars = {} @@ -1745,6 +1746,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for v in self.substates.values() ]: d.update(substate_d) + return d async def __aenter__(self) -> BaseState: diff --git a/reflex/utils/types.py b/reflex/utils/types.py index ecd842e3d..50e18acbf 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -6,7 +6,7 @@ import contextlib import inspect import sys import types -from functools import wraps +from functools import cached_property, wraps from typing import ( Any, Callable, @@ -410,7 +410,7 @@ def is_valid_var_type(type_: Type) -> bool: return _issubclass(type_, StateVar) or serializers.has_serializer(type_) -def is_backend_variable(name: str, cls: Type | None = None) -> bool: +def is_backend_base_variable(name: str, cls: Type) -> bool: """Check if this variable name correspond to a backend variable. Args: @@ -429,31 +429,30 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool: if name.startswith("__"): return False - if cls is not None: - if name.startswith(f"_{cls.__name__}__"): - return False - hints = get_type_hints(cls) - if name in hints: - hint = get_origin(hints[name]) - if hint == ClassVar: - return False + if name.startswith(f"_{cls.__name__}__"): + return False - if name in cls.inherited_backend_vars: + hints = get_type_hints(cls) + if name in hints: + hint = get_origin(hints[name]) + if hint == ClassVar: return False - if name in cls.__dict__: - value = cls.__dict__[name] - if type(value) == classmethod: - return False - if callable(value): - return False - if isinstance(value, types.FunctionType): - return False - # enable after #3573 is merged - # from reflex.vars import ComputedVar - # - # if isinstance(value, ComputedVar): - # return False + if name in cls.inherited_backend_vars: + return False + + if name in cls.__dict__: + value = cls.__dict__[name] + if type(value) == classmethod: + return False + if callable(value): + return False + from reflex.vars import ComputedVar + + if isinstance( + value, (types.FunctionType, property, cached_property, ComputedVar) + ): + return False return True diff --git a/reflex/vars.py b/reflex/vars.py index 7f0bbaa8e..afe3e60c3 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1944,6 +1944,9 @@ class ComputedVar(Var, property): # Whether to track dependencies and cache computed values _cache: bool = dataclasses.field(default=False) + # Whether the computed var is a backend var + _backend: bool = dataclasses.field(default=False) + # The initial value of the computed var _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset()) @@ -1964,6 +1967,7 @@ class ComputedVar(Var, property): deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[int, datetime.timedelta]] = None, + backend: bool | None = None, **kwargs, ): """Initialize a ComputedVar. @@ -1975,11 +1979,16 @@ class ComputedVar(Var, property): deps: Explicit var dependencies to track. auto_deps: Whether var dependencies should be auto-determined. interval: Interval at which the computed var should be updated. + backend: Whether the computed var is a backend var. **kwargs: additional attributes to set on the instance Raises: TypeError: If the computed var dependencies are not Var instances or var names. """ + if backend is None: + backend = fget.__name__.startswith("_") + self._backend = backend + self._initial_value = initial_value self._cache = cache if isinstance(interval, int): @@ -2023,6 +2032,7 @@ class ComputedVar(Var, property): deps=kwargs.get("deps", self._static_deps), auto_deps=kwargs.get("auto_deps", self._auto_deps), interval=kwargs.get("interval", self._update_interval), + backend=kwargs.get("backend", self._backend), _var_name=kwargs.get("_var_name", self._var_name), _var_type=kwargs.get("_var_type", self._var_type), _var_is_local=kwargs.get("_var_is_local", self._var_is_local), @@ -2233,6 +2243,7 @@ def computed_var( deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[datetime.timedelta, int]] = None, + backend: bool | None = None, _deprecated_cached_var: bool = False, **kwargs, ) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]: @@ -2245,6 +2256,7 @@ def computed_var( deps: Explicit var dependencies to track. auto_deps: Whether var dependencies should be auto-determined. interval: Interval at which the computed var should be updated. + backend: Whether the computed var is a backend var. _deprecated_cached_var: Indicate usage of deprecated cached_var partial function. **kwargs: additional attributes to set on the instance @@ -2280,6 +2292,7 @@ def computed_var( deps=deps, auto_deps=auto_deps, interval=interval, + backend=backend, **kwargs, ) diff --git a/tests/test_state.py b/tests/test_state.py index 2ac080931..dbdfc5836 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -925,6 +925,15 @@ class InterdependentState(BaseState): """ return self._v2 * 2 + @rx.var(cache=True, backend=True) + def v2x2_backend(self) -> int: + """Depends on backend var _v2. + + Returns: + backend var _v2 multiplied by 2 + """ + return self._v2 * 2 + @rx.var(cache=True) def v1x2x2(self) -> int: """Depends on ComputedVar v1x2. @@ -1002,11 +1011,11 @@ def test_dirty_computed_var_from_backend_var( Args: interdependent_state: A state with varying Var dependencies. """ + assert InterdependentState._v3._backend is True interdependent_state._v2 = 2 assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4}, } - assert "_v3" in InterdependentState.backend_vars def test_per_state_backend_var(interdependent_state: InterdependentState) -> None: @@ -1295,6 +1304,15 @@ def test_computed_var_dependencies(): """ return self.v + @rx.var(cache=True, backend=True) + def comp_v_backend(self) -> int: + """Direct access backend var. + + Returns: + The value of self.v. + """ + return self.v + @rx.var(cache=True) def comp_v_via_property(self) -> int: """Access v via property. @@ -1345,7 +1363,11 @@ def test_computed_var_dependencies(): return [z in self._z for z in range(5)] cs = ComputedState() - assert cs._computed_var_dependencies["v"] == {"comp_v", "comp_v_via_property"} + assert cs._computed_var_dependencies["v"] == { + "comp_v", + "comp_v_backend", + "comp_v_via_property", + } assert cs._computed_var_dependencies["w"] == {"comp_w"} assert cs._computed_var_dependencies["x"] == {"comp_x"} assert cs._computed_var_dependencies["y"] == {"comp_y"} diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 3cffb267f..e905ff587 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,7 +1,8 @@ import os import typing +from functools import cached_property from pathlib import Path -from typing import Any, ClassVar, List, Literal, Union +from typing import Any, ClassVar, List, Literal, Type, Union import pytest import typer @@ -161,6 +162,14 @@ def test_backend_variable_cls(): def _hidden_method(self): pass + @property + def _hidden_property(self): + pass + + @cached_property + def _cached_hidden_property(self): + pass + return TestBackendVariable @@ -173,10 +182,14 @@ def test_backend_variable_cls(): ("_hidden", True), ("not_hidden", False), ("__dundermethod__", False), + ("_hidden_property", False), + ("_cached_hidden_property", False), ], ) -def test_is_backend_variable(test_backend_variable_cls, input, output): - assert types.is_backend_variable(input, test_backend_variable_cls) == output +def test_is_backend_base_variable( + test_backend_variable_cls: Type[BaseState], input: str, output: bool +): + assert types.is_backend_base_variable(input, test_backend_variable_cls) == output @pytest.mark.parametrize(