Initial values for computed vars (#2670)
* initial values for computed vars draft * add tests, add computed_var overloads * fix darglint * pass initial to substates when calling dict * add tests for for child states * format black * allow None as initial value * rename runtime_only to raises_at_runtime * cleanup unused arguments of ComputedVars * refactor cached_var to be partial of computed_var
This commit is contained in:
parent
82e3be76cf
commit
93f402c773
@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict:
|
||||
A dictionary of the compiled state.
|
||||
"""
|
||||
try:
|
||||
initial_state = state().dict()
|
||||
initial_state = state().dict(initial=True)
|
||||
except Exception as e:
|
||||
console.warn(
|
||||
f"Failed to compile initial state with computed vars, excluding them: {e}"
|
||||
|
@ -46,10 +46,10 @@ from reflex.event import (
|
||||
from reflex.utils import console, format, prerequisites, types
|
||||
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
||||
from reflex.utils.serializers import SerializedType, serialize, serializer
|
||||
from reflex.vars import BaseVar, ComputedVar, Var
|
||||
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
||||
|
||||
Delta = Dict[str, Any]
|
||||
var = ComputedVar
|
||||
var = computed_var
|
||||
|
||||
|
||||
class HeaderData(Base):
|
||||
@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
return super().get_value(key.__wrapped__)
|
||||
return super().get_value(key)
|
||||
|
||||
def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
|
||||
def dict(
|
||||
self, include_computed: bool = True, initial: bool = False, **kwargs
|
||||
) -> dict[str, Any]:
|
||||
"""Convert the object to a dictionary.
|
||||
|
||||
Args:
|
||||
include_computed: Whether to include computed vars.
|
||||
initial: Whether to get the initial value of computed vars.
|
||||
**kwargs: Kwargs to pass to the pydantic dict method.
|
||||
|
||||
Returns:
|
||||
@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
prop_name: self.get_value(getattr(self, prop_name))
|
||||
for prop_name in self.base_vars
|
||||
}
|
||||
computed_vars = (
|
||||
{
|
||||
if initial:
|
||||
computed_vars = {
|
||||
# Include initial computed vars.
|
||||
prop_name: cv._initial_value
|
||||
if isinstance(cv, ComputedVar)
|
||||
and not isinstance(cv._initial_value, types.Unset)
|
||||
else self.get_value(getattr(self, prop_name))
|
||||
for prop_name, cv in self.computed_vars.items()
|
||||
}
|
||||
elif include_computed:
|
||||
computed_vars = {
|
||||
# Include the computed vars.
|
||||
prop_name: self.get_value(getattr(self, prop_name))
|
||||
for prop_name in self.computed_vars
|
||||
}
|
||||
if include_computed
|
||||
else {}
|
||||
)
|
||||
else:
|
||||
computed_vars = {}
|
||||
variables = {**base_vars, **computed_vars}
|
||||
d = {
|
||||
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
||||
}
|
||||
for substate_d in [
|
||||
v.dict(include_computed=include_computed, **kwargs)
|
||||
v.dict(include_computed=include_computed, initial=initial, **kwargs)
|
||||
for v in self.substates.values()
|
||||
]:
|
||||
d.update(substate_d)
|
||||
|
@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple]
|
||||
ArgsSpec = Callable
|
||||
|
||||
|
||||
class Unset:
|
||||
"""A class to represent an unset value.
|
||||
|
||||
This is used to differentiate between a value that is not set and a value that is set to None.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the string representation of the class.
|
||||
|
||||
Returns:
|
||||
The string representation of the class.
|
||||
"""
|
||||
return "Unset"
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Return False when the class is used in a boolean context.
|
||||
|
||||
Returns:
|
||||
False
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def is_generic_alias(cls: GenericType) -> bool:
|
||||
"""Check whether the class is a generic alias.
|
||||
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import dis
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import random
|
||||
@ -1802,24 +1803,26 @@ class ComputedVar(Var, property):
|
||||
# Whether to track dependencies and cache computed values
|
||||
_cache: bool = dataclasses.field(default=False)
|
||||
|
||||
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fget: Callable[[BaseState], Any],
|
||||
fset: Callable[[BaseState, Any], None] | None = None,
|
||||
fdel: Callable[[BaseState], Any] | None = None,
|
||||
doc: str | None = None,
|
||||
initial_value: Any | types.Unset = types.Unset(),
|
||||
cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize a ComputedVar.
|
||||
|
||||
Args:
|
||||
fget: The getter function.
|
||||
fset: The setter function.
|
||||
fdel: The deleter function.
|
||||
doc: The docstring.
|
||||
initial_value: The initial value of the computed var.
|
||||
cache: Whether to cache the computed value.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
"""
|
||||
property.__init__(self, fget, fset, fdel, doc)
|
||||
self._initial_value = initial_value
|
||||
self._cache = cache
|
||||
property.__init__(self, fget)
|
||||
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
|
||||
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
|
||||
BaseVar.__init__(self, **kwargs) # type: ignore
|
||||
@ -1960,21 +1963,39 @@ class ComputedVar(Var, property):
|
||||
return Any
|
||||
|
||||
|
||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
|
||||
"""A field with computed getter that tracks other state dependencies.
|
||||
|
||||
The cached_var will only be recalculated when other state vars that it
|
||||
depends on are modified.
|
||||
def computed_var(
|
||||
fget: Callable[[BaseState], Any] | None = None,
|
||||
initial_value: Any | None = None,
|
||||
cache: bool = False,
|
||||
**kwargs,
|
||||
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
|
||||
"""A ComputedVar decorator with or without kwargs.
|
||||
|
||||
Args:
|
||||
fget: the function that calculates the variable value.
|
||||
fget: The getter function.
|
||||
initial_value: The initial value of the computed var.
|
||||
cache: Whether to cache the computed value.
|
||||
**kwargs: additional attributes to set on the instance
|
||||
|
||||
Returns:
|
||||
ComputedVar that is recomputed when dependencies change.
|
||||
A ComputedVar instance.
|
||||
"""
|
||||
cvar = ComputedVar(fget=fget)
|
||||
cvar._cache = True
|
||||
return cvar
|
||||
if fget is not None:
|
||||
return ComputedVar(fget=fget, cache=cache)
|
||||
|
||||
def wrapper(fget):
|
||||
return ComputedVar(
|
||||
fget=fget,
|
||||
initial_value=initial_value,
|
||||
cache=cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# Partial function of computed_var with cache=True
|
||||
cached_var = functools.partial(computed_var, cache=True)
|
||||
|
||||
|
||||
class CallableVar(BaseVar):
|
||||
|
@ -144,14 +144,19 @@ class ComputedVar(Var):
|
||||
def __init__(
|
||||
self,
|
||||
fget: Callable[[BaseState], Any],
|
||||
fset: Callable[[BaseState, Any], None] | None = None,
|
||||
fdel: Callable[[BaseState], Any] | None = None,
|
||||
doc: str | None = None,
|
||||
**kwargs,
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, func) -> None: ...
|
||||
|
||||
@overload
|
||||
def computed_var(
|
||||
fget: Callable[[BaseState], Any] | None = None,
|
||||
initial_value: Any | None = None,
|
||||
**kwargs,
|
||||
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
|
||||
@overload
|
||||
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||
|
||||
class CallableVar(BaseVar):
|
||||
|
@ -9,8 +9,8 @@ from reflex.base import Base
|
||||
from reflex.state import BaseState
|
||||
from reflex.vars import (
|
||||
BaseVar,
|
||||
ComputedVar,
|
||||
Var,
|
||||
computed_var,
|
||||
)
|
||||
|
||||
test_vars = [
|
||||
@ -46,7 +46,7 @@ def ParentState(TestObj):
|
||||
foo: int
|
||||
bar: int
|
||||
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def var_without_annotation(self):
|
||||
return TestObj
|
||||
|
||||
@ -56,7 +56,7 @@ def ParentState(TestObj):
|
||||
@pytest.fixture
|
||||
def ChildState(ParentState, TestObj):
|
||||
class ChildState(ParentState):
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def var_without_annotation(self):
|
||||
return TestObj
|
||||
|
||||
@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj):
|
||||
@pytest.fixture
|
||||
def GrandChildState(ChildState, TestObj):
|
||||
class GrandChildState(ChildState):
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def var_without_annotation(self):
|
||||
return TestObj
|
||||
|
||||
@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj):
|
||||
@pytest.fixture
|
||||
def StateWithAnyVar(TestObj):
|
||||
class StateWithAnyVar(BaseState):
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def var_without_annotation(self) -> typing.Any:
|
||||
return TestObj
|
||||
|
||||
@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj):
|
||||
@pytest.fixture
|
||||
def StateWithCorrectVarAnnotation():
|
||||
class StateWithCorrectVarAnnotation(BaseState):
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def var_with_annotation(self) -> str:
|
||||
return "Correct annotation"
|
||||
|
||||
@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation():
|
||||
@pytest.fixture
|
||||
def StateWithWrongVarAnnotation(TestObj):
|
||||
class StateWithWrongVarAnnotation(BaseState):
|
||||
@ComputedVar
|
||||
@computed_var
|
||||
def var_with_annotation(self) -> str:
|
||||
return TestObj
|
||||
|
||||
return StateWithWrongVarAnnotation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def StateWithInitialComputedVar():
|
||||
class StateWithInitialComputedVar(BaseState):
|
||||
@computed_var(initial_value="Initial value")
|
||||
def var_with_initial_value(self) -> str:
|
||||
return "Runtime value"
|
||||
|
||||
return StateWithInitialComputedVar
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ChildWithInitialComputedVar(StateWithInitialComputedVar):
|
||||
class ChildWithInitialComputedVar(StateWithInitialComputedVar):
|
||||
@computed_var(initial_value="Initial value")
|
||||
def var_with_initial_value_child(self) -> str:
|
||||
return "Runtime value"
|
||||
|
||||
return ChildWithInitialComputedVar
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def StateWithRuntimeOnlyVar():
|
||||
class StateWithRuntimeOnlyVar(BaseState):
|
||||
@computed_var(initial_value=None)
|
||||
def var_raises_at_runtime(self) -> str:
|
||||
raise ValueError("So nicht, mein Freund")
|
||||
|
||||
return StateWithRuntimeOnlyVar
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
|
||||
class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
|
||||
@computed_var(initial_value="Initial value")
|
||||
def var_raises_at_runtime_child(self) -> str:
|
||||
raise ValueError("So nicht, mein Freund")
|
||||
|
||||
return ChildWithRuntimeOnlyVar
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prop,expected",
|
||||
zip(
|
||||
@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fixture,var_name,expected_initial,expected_runtime,raises_at_runtime",
|
||||
[
|
||||
(
|
||||
"StateWithInitialComputedVar",
|
||||
"var_with_initial_value",
|
||||
"Initial value",
|
||||
"Runtime value",
|
||||
False,
|
||||
),
|
||||
(
|
||||
"ChildWithInitialComputedVar",
|
||||
"var_with_initial_value_child",
|
||||
"Initial value",
|
||||
"Runtime value",
|
||||
False,
|
||||
),
|
||||
(
|
||||
"StateWithRuntimeOnlyVar",
|
||||
"var_raises_at_runtime",
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
),
|
||||
(
|
||||
"ChildWithRuntimeOnlyVar",
|
||||
"var_raises_at_runtime_child",
|
||||
"Initial value",
|
||||
None,
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_state_with_initial_computed_var(
|
||||
request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime
|
||||
):
|
||||
"""Test that the initial and runtime values of a computed var are correct.
|
||||
|
||||
Args:
|
||||
request: Fixture Request.
|
||||
fixture: The state fixture.
|
||||
var_name: The name of the computed var.
|
||||
expected_initial: The expected initial value of the computed var.
|
||||
expected_runtime: The expected runtime value of the computed var.
|
||||
raises_at_runtime: Whether the computed var is runtime only.
|
||||
"""
|
||||
state = request.getfixturevalue(fixture)()
|
||||
state_name = state.get_full_name()
|
||||
initial_dict = state.dict(initial=True)[state_name]
|
||||
assert initial_dict[var_name] == expected_initial
|
||||
|
||||
if raises_at_runtime:
|
||||
with pytest.raises(ValueError):
|
||||
state.dict()[state_name][var_name]
|
||||
else:
|
||||
runtime_dict = state.dict()[state_name]
|
||||
assert runtime_dict[var_name] == expected_runtime
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"out, expected",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user