diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index e408d012a..ac53e8b02 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict: A dictionary of the compiled state. """ try: - initial_state = state().dict() + initial_state = state().dict(initial=True) except Exception as e: console.warn( f"Failed to compile initial state with computed vars, excluding them: {e}" diff --git a/reflex/state.py b/reflex/state.py index f6bf2161b..86ea96c56 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -46,10 +46,10 @@ from reflex.event import ( from reflex.utils import console, format, prerequisites, types from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.serializers import SerializedType, serialize, serializer -from reflex.vars import BaseVar, ComputedVar, Var +from reflex.vars import BaseVar, ComputedVar, Var, computed_var Delta = Dict[str, Any] -var = ComputedVar +var = computed_var class HeaderData(Base): @@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return super().get_value(key.__wrapped__) return super().get_value(key) - def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]: + def dict( + self, include_computed: bool = True, initial: bool = False, **kwargs + ) -> dict[str, Any]: """Convert the object to a dictionary. Args: include_computed: Whether to include computed vars. + initial: Whether to get the initial value of computed vars. **kwargs: Kwargs to pass to the pydantic dict method. Returns: @@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): prop_name: self.get_value(getattr(self, prop_name)) for prop_name in self.base_vars } - computed_vars = ( - { + if initial: + computed_vars = { + # Include initial computed vars. + prop_name: cv._initial_value + if isinstance(cv, ComputedVar) + and not isinstance(cv._initial_value, types.Unset) + else self.get_value(getattr(self, prop_name)) + for prop_name, cv in self.computed_vars.items() + } + elif include_computed: + computed_vars = { # Include the computed vars. prop_name: self.get_value(getattr(self, prop_name)) for prop_name in self.computed_vars } - if include_computed - else {} - ) + else: + computed_vars = {} variables = {**base_vars, **computed_vars} d = { self.get_full_name(): {k: variables[k] for k in sorted(variables)}, } for substate_d in [ - v.dict(include_computed=include_computed, **kwargs) + v.dict(include_computed=include_computed, initial=initial, **kwargs) for v in self.substates.values() ]: d.update(substate_d) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 4187d4434..125dc2615 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple] ArgsSpec = Callable +class Unset: + """A class to represent an unset value. + + This is used to differentiate between a value that is not set and a value that is set to None. + """ + + def __repr__(self) -> str: + """Return the string representation of the class. + + Returns: + The string representation of the class. + """ + return "Unset" + + def __bool__(self) -> bool: + """Return False when the class is used in a boolean context. + + Returns: + False + """ + return False + + def is_generic_alias(cls: GenericType) -> bool: """Check whether the class is a generic alias. diff --git a/reflex/vars.py b/reflex/vars.py index 93f8363b9..ed01e8d33 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -4,6 +4,7 @@ from __future__ import annotations import contextlib import dataclasses import dis +import functools import inspect import json import random @@ -1802,24 +1803,26 @@ class ComputedVar(Var, property): # Whether to track dependencies and cache computed values _cache: bool = dataclasses.field(default=False) + _initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset) + def __init__( self, fget: Callable[[BaseState], Any], - fset: Callable[[BaseState, Any], None] | None = None, - fdel: Callable[[BaseState], Any] | None = None, - doc: str | None = None, + initial_value: Any | types.Unset = types.Unset(), + cache: bool = False, **kwargs, ): """Initialize a ComputedVar. Args: fget: The getter function. - fset: The setter function. - fdel: The deleter function. - doc: The docstring. + initial_value: The initial value of the computed var. + cache: Whether to cache the computed value. **kwargs: additional attributes to set on the instance """ - property.__init__(self, fget, fset, fdel, doc) + self._initial_value = initial_value + self._cache = cache + property.__init__(self, fget) kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__) kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type()) BaseVar.__init__(self, **kwargs) # type: ignore @@ -1960,21 +1963,39 @@ class ComputedVar(Var, property): return Any -def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: - """A field with computed getter that tracks other state dependencies. - - The cached_var will only be recalculated when other state vars that it - depends on are modified. +def computed_var( + fget: Callable[[BaseState], Any] | None = None, + initial_value: Any | None = None, + cache: bool = False, + **kwargs, +) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]: + """A ComputedVar decorator with or without kwargs. Args: - fget: the function that calculates the variable value. + fget: The getter function. + initial_value: The initial value of the computed var. + cache: Whether to cache the computed value. + **kwargs: additional attributes to set on the instance Returns: - ComputedVar that is recomputed when dependencies change. + A ComputedVar instance. """ - cvar = ComputedVar(fget=fget) - cvar._cache = True - return cvar + if fget is not None: + return ComputedVar(fget=fget, cache=cache) + + def wrapper(fget): + return ComputedVar( + fget=fget, + initial_value=initial_value, + cache=cache, + **kwargs, + ) + + return wrapper + + +# Partial function of computed_var with cache=True +cached_var = functools.partial(computed_var, cache=True) class CallableVar(BaseVar): diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 048c99f40..959a54f74 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -144,14 +144,19 @@ class ComputedVar(Var): def __init__( self, fget: Callable[[BaseState], Any], - fset: Callable[[BaseState, Any], None] | None = None, - fdel: Callable[[BaseState], Any] | None = None, - doc: str | None = None, **kwargs, ) -> None: ... @overload def __init__(self, func) -> None: ... +@overload +def computed_var( + fget: Callable[[BaseState], Any] | None = None, + initial_value: Any | None = None, + **kwargs, +) -> Callable[[Callable[[Any], Any]], ComputedVar]: ... +@overload +def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ... def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... class CallableVar(BaseVar): diff --git a/tests/test_var.py b/tests/test_var.py index a4797532c..797d48ed7 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -9,8 +9,8 @@ from reflex.base import Base from reflex.state import BaseState from reflex.vars import ( BaseVar, - ComputedVar, Var, + computed_var, ) test_vars = [ @@ -46,7 +46,7 @@ def ParentState(TestObj): foo: int bar: int - @ComputedVar + @computed_var def var_without_annotation(self): return TestObj @@ -56,7 +56,7 @@ def ParentState(TestObj): @pytest.fixture def ChildState(ParentState, TestObj): class ChildState(ParentState): - @ComputedVar + @computed_var def var_without_annotation(self): return TestObj @@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj): @pytest.fixture def GrandChildState(ChildState, TestObj): class GrandChildState(ChildState): - @ComputedVar + @computed_var def var_without_annotation(self): return TestObj @@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj): @pytest.fixture def StateWithAnyVar(TestObj): class StateWithAnyVar(BaseState): - @ComputedVar + @computed_var def var_without_annotation(self) -> typing.Any: return TestObj @@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj): @pytest.fixture def StateWithCorrectVarAnnotation(): class StateWithCorrectVarAnnotation(BaseState): - @ComputedVar + @computed_var def var_with_annotation(self) -> str: return "Correct annotation" @@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation(): @pytest.fixture def StateWithWrongVarAnnotation(TestObj): class StateWithWrongVarAnnotation(BaseState): - @ComputedVar + @computed_var def var_with_annotation(self) -> str: return TestObj return StateWithWrongVarAnnotation +@pytest.fixture +def StateWithInitialComputedVar(): + class StateWithInitialComputedVar(BaseState): + @computed_var(initial_value="Initial value") + def var_with_initial_value(self) -> str: + return "Runtime value" + + return StateWithInitialComputedVar + + +@pytest.fixture +def ChildWithInitialComputedVar(StateWithInitialComputedVar): + class ChildWithInitialComputedVar(StateWithInitialComputedVar): + @computed_var(initial_value="Initial value") + def var_with_initial_value_child(self) -> str: + return "Runtime value" + + return ChildWithInitialComputedVar + + +@pytest.fixture +def StateWithRuntimeOnlyVar(): + class StateWithRuntimeOnlyVar(BaseState): + @computed_var(initial_value=None) + def var_raises_at_runtime(self) -> str: + raise ValueError("So nicht, mein Freund") + + return StateWithRuntimeOnlyVar + + +@pytest.fixture +def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar): + class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar): + @computed_var(initial_value="Initial value") + def var_raises_at_runtime_child(self) -> str: + raise ValueError("So nicht, mein Freund") + + return ChildWithRuntimeOnlyVar + + @pytest.mark.parametrize( "prop,expected", zip( @@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name): ) +@pytest.mark.parametrize( + "fixture,var_name,expected_initial,expected_runtime,raises_at_runtime", + [ + ( + "StateWithInitialComputedVar", + "var_with_initial_value", + "Initial value", + "Runtime value", + False, + ), + ( + "ChildWithInitialComputedVar", + "var_with_initial_value_child", + "Initial value", + "Runtime value", + False, + ), + ( + "StateWithRuntimeOnlyVar", + "var_raises_at_runtime", + None, + None, + True, + ), + ( + "ChildWithRuntimeOnlyVar", + "var_raises_at_runtime_child", + "Initial value", + None, + True, + ), + ], +) +def test_state_with_initial_computed_var( + request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime +): + """Test that the initial and runtime values of a computed var are correct. + + Args: + request: Fixture Request. + fixture: The state fixture. + var_name: The name of the computed var. + expected_initial: The expected initial value of the computed var. + expected_runtime: The expected runtime value of the computed var. + raises_at_runtime: Whether the computed var is runtime only. + """ + state = request.getfixturevalue(fixture)() + state_name = state.get_full_name() + initial_dict = state.dict(initial=True)[state_name] + assert initial_dict[var_name] == expected_initial + + if raises_at_runtime: + with pytest.raises(ValueError): + state.dict()[state_name][var_name] + else: + runtime_dict = state.dict()[state_name] + assert runtime_dict[var_name] == expected_runtime + + @pytest.mark.parametrize( "out, expected", [