Initial values for computed vars (#2670)

* initial values for computed vars draft

* add tests, add computed_var overloads

* fix darglint

* pass initial to substates when calling dict

* add tests for for child states

* format black

* allow None as initial value

* rename runtime_only to raises_at_runtime

* cleanup unused arguments of ComputedVars

* refactor cached_var to be partial of computed_var
This commit is contained in:
benedikt-bartscher 2024-02-24 22:45:07 +01:00 committed by GitHub
parent 82e3be76cf
commit 93f402c773
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 196 additions and 37 deletions

View File

@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict:
A dictionary of the compiled state.
"""
try:
initial_state = state().dict()
initial_state = state().dict(initial=True)
except Exception as e:
console.warn(
f"Failed to compile initial state with computed vars, excluding them: {e}"

View File

@ -46,10 +46,10 @@ from reflex.event import (
from reflex.utils import console, format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.vars import BaseVar, ComputedVar, Var
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
Delta = Dict[str, Any]
var = ComputedVar
var = computed_var
class HeaderData(Base):
@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return super().get_value(key.__wrapped__)
return super().get_value(key)
def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
def dict(
self, include_computed: bool = True, initial: bool = False, **kwargs
) -> dict[str, Any]:
"""Convert the object to a dictionary.
Args:
include_computed: Whether to include computed vars.
initial: Whether to get the initial value of computed vars.
**kwargs: Kwargs to pass to the pydantic dict method.
Returns:
@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.base_vars
}
computed_vars = (
{
if initial:
computed_vars = {
# Include initial computed vars.
prop_name: cv._initial_value
if isinstance(cv, ComputedVar)
and not isinstance(cv._initial_value, types.Unset)
else self.get_value(getattr(self, prop_name))
for prop_name, cv in self.computed_vars.items()
}
elif include_computed:
computed_vars = {
# Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.computed_vars
}
if include_computed
else {}
)
else:
computed_vars = {}
variables = {**base_vars, **computed_vars}
d = {
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
}
for substate_d in [
v.dict(include_computed=include_computed, **kwargs)
v.dict(include_computed=include_computed, initial=initial, **kwargs)
for v in self.substates.values()
]:
d.update(substate_d)

View File

@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple]
ArgsSpec = Callable
class Unset:
"""A class to represent an unset value.
This is used to differentiate between a value that is not set and a value that is set to None.
"""
def __repr__(self) -> str:
"""Return the string representation of the class.
Returns:
The string representation of the class.
"""
return "Unset"
def __bool__(self) -> bool:
"""Return False when the class is used in a boolean context.
Returns:
False
"""
return False
def is_generic_alias(cls: GenericType) -> bool:
"""Check whether the class is a generic alias.

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import contextlib
import dataclasses
import dis
import functools
import inspect
import json
import random
@ -1802,24 +1803,26 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False)
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
def __init__(
self,
fget: Callable[[BaseState], Any],
fset: Callable[[BaseState, Any], None] | None = None,
fdel: Callable[[BaseState], Any] | None = None,
doc: str | None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: bool = False,
**kwargs,
):
"""Initialize a ComputedVar.
Args:
fget: The getter function.
fset: The setter function.
fdel: The deleter function.
doc: The docstring.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
**kwargs: additional attributes to set on the instance
"""
property.__init__(self, fget, fset, fdel, doc)
self._initial_value = initial_value
self._cache = cache
property.__init__(self, fget)
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
BaseVar.__init__(self, **kwargs) # type: ignore
@ -1960,21 +1963,39 @@ class ComputedVar(Var, property):
return Any
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
"""A field with computed getter that tracks other state dependencies.
The cached_var will only be recalculated when other state vars that it
depends on are modified.
def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
cache: bool = False,
**kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
"""A ComputedVar decorator with or without kwargs.
Args:
fget: the function that calculates the variable value.
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
**kwargs: additional attributes to set on the instance
Returns:
ComputedVar that is recomputed when dependencies change.
A ComputedVar instance.
"""
cvar = ComputedVar(fget=fget)
cvar._cache = True
return cvar
if fget is not None:
return ComputedVar(fget=fget, cache=cache)
def wrapper(fget):
return ComputedVar(
fget=fget,
initial_value=initial_value,
cache=cache,
**kwargs,
)
return wrapper
# Partial function of computed_var with cache=True
cached_var = functools.partial(computed_var, cache=True)
class CallableVar(BaseVar):

View File

@ -144,14 +144,19 @@ class ComputedVar(Var):
def __init__(
self,
fget: Callable[[BaseState], Any],
fset: Callable[[BaseState, Any], None] | None = None,
fdel: Callable[[BaseState], Any] | None = None,
doc: str | None = None,
**kwargs,
) -> None: ...
@overload
def __init__(self, func) -> None: ...
@overload
def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
class CallableVar(BaseVar):

View File

@ -9,8 +9,8 @@ from reflex.base import Base
from reflex.state import BaseState
from reflex.vars import (
BaseVar,
ComputedVar,
Var,
computed_var,
)
test_vars = [
@ -46,7 +46,7 @@ def ParentState(TestObj):
foo: int
bar: int
@ComputedVar
@computed_var
def var_without_annotation(self):
return TestObj
@ -56,7 +56,7 @@ def ParentState(TestObj):
@pytest.fixture
def ChildState(ParentState, TestObj):
class ChildState(ParentState):
@ComputedVar
@computed_var
def var_without_annotation(self):
return TestObj
@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj):
@pytest.fixture
def GrandChildState(ChildState, TestObj):
class GrandChildState(ChildState):
@ComputedVar
@computed_var
def var_without_annotation(self):
return TestObj
@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj):
@pytest.fixture
def StateWithAnyVar(TestObj):
class StateWithAnyVar(BaseState):
@ComputedVar
@computed_var
def var_without_annotation(self) -> typing.Any:
return TestObj
@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj):
@pytest.fixture
def StateWithCorrectVarAnnotation():
class StateWithCorrectVarAnnotation(BaseState):
@ComputedVar
@computed_var
def var_with_annotation(self) -> str:
return "Correct annotation"
@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation():
@pytest.fixture
def StateWithWrongVarAnnotation(TestObj):
class StateWithWrongVarAnnotation(BaseState):
@ComputedVar
@computed_var
def var_with_annotation(self) -> str:
return TestObj
return StateWithWrongVarAnnotation
@pytest.fixture
def StateWithInitialComputedVar():
class StateWithInitialComputedVar(BaseState):
@computed_var(initial_value="Initial value")
def var_with_initial_value(self) -> str:
return "Runtime value"
return StateWithInitialComputedVar
@pytest.fixture
def ChildWithInitialComputedVar(StateWithInitialComputedVar):
class ChildWithInitialComputedVar(StateWithInitialComputedVar):
@computed_var(initial_value="Initial value")
def var_with_initial_value_child(self) -> str:
return "Runtime value"
return ChildWithInitialComputedVar
@pytest.fixture
def StateWithRuntimeOnlyVar():
class StateWithRuntimeOnlyVar(BaseState):
@computed_var(initial_value=None)
def var_raises_at_runtime(self) -> str:
raise ValueError("So nicht, mein Freund")
return StateWithRuntimeOnlyVar
@pytest.fixture
def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
@computed_var(initial_value="Initial value")
def var_raises_at_runtime_child(self) -> str:
raise ValueError("So nicht, mein Freund")
return ChildWithRuntimeOnlyVar
@pytest.mark.parametrize(
"prop,expected",
zip(
@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
)
@pytest.mark.parametrize(
"fixture,var_name,expected_initial,expected_runtime,raises_at_runtime",
[
(
"StateWithInitialComputedVar",
"var_with_initial_value",
"Initial value",
"Runtime value",
False,
),
(
"ChildWithInitialComputedVar",
"var_with_initial_value_child",
"Initial value",
"Runtime value",
False,
),
(
"StateWithRuntimeOnlyVar",
"var_raises_at_runtime",
None,
None,
True,
),
(
"ChildWithRuntimeOnlyVar",
"var_raises_at_runtime_child",
"Initial value",
None,
True,
),
],
)
def test_state_with_initial_computed_var(
request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime
):
"""Test that the initial and runtime values of a computed var are correct.
Args:
request: Fixture Request.
fixture: The state fixture.
var_name: The name of the computed var.
expected_initial: The expected initial value of the computed var.
expected_runtime: The expected runtime value of the computed var.
raises_at_runtime: Whether the computed var is runtime only.
"""
state = request.getfixturevalue(fixture)()
state_name = state.get_full_name()
initial_dict = state.dict(initial=True)[state_name]
assert initial_dict[var_name] == expected_initial
if raises_at_runtime:
with pytest.raises(ValueError):
state.dict()[state_name][var_name]
else:
runtime_dict = state.dict()[state_name]
assert runtime_dict[var_name] == expected_runtime
@pytest.mark.parametrize(
"out, expected",
[