Initial values for computed vars (#2670)

* initial values for computed vars draft

* add tests, add computed_var overloads

* fix darglint

* pass initial to substates when calling dict

* add tests for for child states

* format black

* allow None as initial value

* rename runtime_only to raises_at_runtime

* cleanup unused arguments of ComputedVars

* refactor cached_var to be partial of computed_var
This commit is contained in:
benedikt-bartscher 2024-02-24 22:45:07 +01:00 committed by GitHub
parent 82e3be76cf
commit 93f402c773
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 196 additions and 37 deletions

View File

@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict:
A dictionary of the compiled state. A dictionary of the compiled state.
""" """
try: try:
initial_state = state().dict() initial_state = state().dict(initial=True)
except Exception as e: except Exception as e:
console.warn( console.warn(
f"Failed to compile initial state with computed vars, excluding them: {e}" f"Failed to compile initial state with computed vars, excluding them: {e}"

View File

@ -46,10 +46,10 @@ from reflex.event import (
from reflex.utils import console, format, prerequisites, types from reflex.utils import console, format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.vars import BaseVar, ComputedVar, Var from reflex.vars import BaseVar, ComputedVar, Var, computed_var
Delta = Dict[str, Any] Delta = Dict[str, Any]
var = ComputedVar var = computed_var
class HeaderData(Base): class HeaderData(Base):
@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return super().get_value(key.__wrapped__) return super().get_value(key.__wrapped__)
return super().get_value(key) return super().get_value(key)
def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]: def dict(
self, include_computed: bool = True, initial: bool = False, **kwargs
) -> dict[str, Any]:
"""Convert the object to a dictionary. """Convert the object to a dictionary.
Args: Args:
include_computed: Whether to include computed vars. include_computed: Whether to include computed vars.
initial: Whether to get the initial value of computed vars.
**kwargs: Kwargs to pass to the pydantic dict method. **kwargs: Kwargs to pass to the pydantic dict method.
Returns: Returns:
@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
prop_name: self.get_value(getattr(self, prop_name)) prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.base_vars for prop_name in self.base_vars
} }
computed_vars = ( if initial:
{ computed_vars = {
# Include initial computed vars.
prop_name: cv._initial_value
if isinstance(cv, ComputedVar)
and not isinstance(cv._initial_value, types.Unset)
else self.get_value(getattr(self, prop_name))
for prop_name, cv in self.computed_vars.items()
}
elif include_computed:
computed_vars = {
# Include the computed vars. # Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name)) prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.computed_vars for prop_name in self.computed_vars
} }
if include_computed else:
else {} computed_vars = {}
)
variables = {**base_vars, **computed_vars} variables = {**base_vars, **computed_vars}
d = { d = {
self.get_full_name(): {k: variables[k] for k in sorted(variables)}, self.get_full_name(): {k: variables[k] for k in sorted(variables)},
} }
for substate_d in [ for substate_d in [
v.dict(include_computed=include_computed, **kwargs) v.dict(include_computed=include_computed, initial=initial, **kwargs)
for v in self.substates.values() for v in self.substates.values()
]: ]:
d.update(substate_d) d.update(substate_d)

View File

@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple]
ArgsSpec = Callable ArgsSpec = Callable
class Unset:
"""A class to represent an unset value.
This is used to differentiate between a value that is not set and a value that is set to None.
"""
def __repr__(self) -> str:
"""Return the string representation of the class.
Returns:
The string representation of the class.
"""
return "Unset"
def __bool__(self) -> bool:
"""Return False when the class is used in a boolean context.
Returns:
False
"""
return False
def is_generic_alias(cls: GenericType) -> bool: def is_generic_alias(cls: GenericType) -> bool:
"""Check whether the class is a generic alias. """Check whether the class is a generic alias.

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import contextlib import contextlib
import dataclasses import dataclasses
import dis import dis
import functools
import inspect import inspect
import json import json
import random import random
@ -1802,24 +1803,26 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values # Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False) _cache: bool = dataclasses.field(default=False)
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
def __init__( def __init__(
self, self,
fget: Callable[[BaseState], Any], fget: Callable[[BaseState], Any],
fset: Callable[[BaseState, Any], None] | None = None, initial_value: Any | types.Unset = types.Unset(),
fdel: Callable[[BaseState], Any] | None = None, cache: bool = False,
doc: str | None = None,
**kwargs, **kwargs,
): ):
"""Initialize a ComputedVar. """Initialize a ComputedVar.
Args: Args:
fget: The getter function. fget: The getter function.
fset: The setter function. initial_value: The initial value of the computed var.
fdel: The deleter function. cache: Whether to cache the computed value.
doc: The docstring.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
""" """
property.__init__(self, fget, fset, fdel, doc) self._initial_value = initial_value
self._cache = cache
property.__init__(self, fget)
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__) kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type()) kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
BaseVar.__init__(self, **kwargs) # type: ignore BaseVar.__init__(self, **kwargs) # type: ignore
@ -1960,21 +1963,39 @@ class ComputedVar(Var, property):
return Any return Any
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: def computed_var(
"""A field with computed getter that tracks other state dependencies. fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
The cached_var will only be recalculated when other state vars that it cache: bool = False,
depends on are modified. **kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
"""A ComputedVar decorator with or without kwargs.
Args: Args:
fget: the function that calculates the variable value. fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
**kwargs: additional attributes to set on the instance
Returns: Returns:
ComputedVar that is recomputed when dependencies change. A ComputedVar instance.
""" """
cvar = ComputedVar(fget=fget) if fget is not None:
cvar._cache = True return ComputedVar(fget=fget, cache=cache)
return cvar
def wrapper(fget):
return ComputedVar(
fget=fget,
initial_value=initial_value,
cache=cache,
**kwargs,
)
return wrapper
# Partial function of computed_var with cache=True
cached_var = functools.partial(computed_var, cache=True)
class CallableVar(BaseVar): class CallableVar(BaseVar):

View File

@ -144,14 +144,19 @@ class ComputedVar(Var):
def __init__( def __init__(
self, self,
fget: Callable[[BaseState], Any], fget: Callable[[BaseState], Any],
fset: Callable[[BaseState, Any], None] | None = None,
fdel: Callable[[BaseState], Any] | None = None,
doc: str | None = None,
**kwargs, **kwargs,
) -> None: ... ) -> None: ...
@overload @overload
def __init__(self, func) -> None: ... def __init__(self, func) -> None: ...
@overload
def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
class CallableVar(BaseVar): class CallableVar(BaseVar):

View File

@ -9,8 +9,8 @@ from reflex.base import Base
from reflex.state import BaseState from reflex.state import BaseState
from reflex.vars import ( from reflex.vars import (
BaseVar, BaseVar,
ComputedVar,
Var, Var,
computed_var,
) )
test_vars = [ test_vars = [
@ -46,7 +46,7 @@ def ParentState(TestObj):
foo: int foo: int
bar: int bar: int
@ComputedVar @computed_var
def var_without_annotation(self): def var_without_annotation(self):
return TestObj return TestObj
@ -56,7 +56,7 @@ def ParentState(TestObj):
@pytest.fixture @pytest.fixture
def ChildState(ParentState, TestObj): def ChildState(ParentState, TestObj):
class ChildState(ParentState): class ChildState(ParentState):
@ComputedVar @computed_var
def var_without_annotation(self): def var_without_annotation(self):
return TestObj return TestObj
@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj):
@pytest.fixture @pytest.fixture
def GrandChildState(ChildState, TestObj): def GrandChildState(ChildState, TestObj):
class GrandChildState(ChildState): class GrandChildState(ChildState):
@ComputedVar @computed_var
def var_without_annotation(self): def var_without_annotation(self):
return TestObj return TestObj
@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj):
@pytest.fixture @pytest.fixture
def StateWithAnyVar(TestObj): def StateWithAnyVar(TestObj):
class StateWithAnyVar(BaseState): class StateWithAnyVar(BaseState):
@ComputedVar @computed_var
def var_without_annotation(self) -> typing.Any: def var_without_annotation(self) -> typing.Any:
return TestObj return TestObj
@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj):
@pytest.fixture @pytest.fixture
def StateWithCorrectVarAnnotation(): def StateWithCorrectVarAnnotation():
class StateWithCorrectVarAnnotation(BaseState): class StateWithCorrectVarAnnotation(BaseState):
@ComputedVar @computed_var
def var_with_annotation(self) -> str: def var_with_annotation(self) -> str:
return "Correct annotation" return "Correct annotation"
@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation():
@pytest.fixture @pytest.fixture
def StateWithWrongVarAnnotation(TestObj): def StateWithWrongVarAnnotation(TestObj):
class StateWithWrongVarAnnotation(BaseState): class StateWithWrongVarAnnotation(BaseState):
@ComputedVar @computed_var
def var_with_annotation(self) -> str: def var_with_annotation(self) -> str:
return TestObj return TestObj
return StateWithWrongVarAnnotation return StateWithWrongVarAnnotation
@pytest.fixture
def StateWithInitialComputedVar():
class StateWithInitialComputedVar(BaseState):
@computed_var(initial_value="Initial value")
def var_with_initial_value(self) -> str:
return "Runtime value"
return StateWithInitialComputedVar
@pytest.fixture
def ChildWithInitialComputedVar(StateWithInitialComputedVar):
class ChildWithInitialComputedVar(StateWithInitialComputedVar):
@computed_var(initial_value="Initial value")
def var_with_initial_value_child(self) -> str:
return "Runtime value"
return ChildWithInitialComputedVar
@pytest.fixture
def StateWithRuntimeOnlyVar():
class StateWithRuntimeOnlyVar(BaseState):
@computed_var(initial_value=None)
def var_raises_at_runtime(self) -> str:
raise ValueError("So nicht, mein Freund")
return StateWithRuntimeOnlyVar
@pytest.fixture
def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
@computed_var(initial_value="Initial value")
def var_raises_at_runtime_child(self) -> str:
raise ValueError("So nicht, mein Freund")
return ChildWithRuntimeOnlyVar
@pytest.mark.parametrize( @pytest.mark.parametrize(
"prop,expected", "prop,expected",
zip( zip(
@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
) )
@pytest.mark.parametrize(
"fixture,var_name,expected_initial,expected_runtime,raises_at_runtime",
[
(
"StateWithInitialComputedVar",
"var_with_initial_value",
"Initial value",
"Runtime value",
False,
),
(
"ChildWithInitialComputedVar",
"var_with_initial_value_child",
"Initial value",
"Runtime value",
False,
),
(
"StateWithRuntimeOnlyVar",
"var_raises_at_runtime",
None,
None,
True,
),
(
"ChildWithRuntimeOnlyVar",
"var_raises_at_runtime_child",
"Initial value",
None,
True,
),
],
)
def test_state_with_initial_computed_var(
request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime
):
"""Test that the initial and runtime values of a computed var are correct.
Args:
request: Fixture Request.
fixture: The state fixture.
var_name: The name of the computed var.
expected_initial: The expected initial value of the computed var.
expected_runtime: The expected runtime value of the computed var.
raises_at_runtime: Whether the computed var is runtime only.
"""
state = request.getfixturevalue(fixture)()
state_name = state.get_full_name()
initial_dict = state.dict(initial=True)[state_name]
assert initial_dict[var_name] == expected_initial
if raises_at_runtime:
with pytest.raises(ValueError):
state.dict()[state_name][var_name]
else:
runtime_dict = state.dict()[state_name]
assert runtime_dict[var_name] == expected_runtime
@pytest.mark.parametrize( @pytest.mark.parametrize(
"out, expected", "out, expected",
[ [