Initial values for computed vars (#2670)
* initial values for computed vars draft * add tests, add computed_var overloads * fix darglint * pass initial to substates when calling dict * add tests for for child states * format black * allow None as initial value * rename runtime_only to raises_at_runtime * cleanup unused arguments of ComputedVars * refactor cached_var to be partial of computed_var
This commit is contained in:
parent
82e3be76cf
commit
93f402c773
@ -138,7 +138,7 @@ def compile_state(state: Type[BaseState]) -> dict:
|
|||||||
A dictionary of the compiled state.
|
A dictionary of the compiled state.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
initial_state = state().dict()
|
initial_state = state().dict(initial=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.warn(
|
console.warn(
|
||||||
f"Failed to compile initial state with computed vars, excluding them: {e}"
|
f"Failed to compile initial state with computed vars, excluding them: {e}"
|
||||||
|
@ -46,10 +46,10 @@ from reflex.event import (
|
|||||||
from reflex.utils import console, format, prerequisites, types
|
from reflex.utils import console, format, prerequisites, types
|
||||||
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
||||||
from reflex.utils.serializers import SerializedType, serialize, serializer
|
from reflex.utils.serializers import SerializedType, serialize, serializer
|
||||||
from reflex.vars import BaseVar, ComputedVar, Var
|
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
||||||
|
|
||||||
Delta = Dict[str, Any]
|
Delta = Dict[str, Any]
|
||||||
var = ComputedVar
|
var = computed_var
|
||||||
|
|
||||||
|
|
||||||
class HeaderData(Base):
|
class HeaderData(Base):
|
||||||
@ -1328,11 +1328,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
return super().get_value(key.__wrapped__)
|
return super().get_value(key.__wrapped__)
|
||||||
return super().get_value(key)
|
return super().get_value(key)
|
||||||
|
|
||||||
def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]:
|
def dict(
|
||||||
|
self, include_computed: bool = True, initial: bool = False, **kwargs
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Convert the object to a dictionary.
|
"""Convert the object to a dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
include_computed: Whether to include computed vars.
|
include_computed: Whether to include computed vars.
|
||||||
|
initial: Whether to get the initial value of computed vars.
|
||||||
**kwargs: Kwargs to pass to the pydantic dict method.
|
**kwargs: Kwargs to pass to the pydantic dict method.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1348,21 +1351,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
prop_name: self.get_value(getattr(self, prop_name))
|
prop_name: self.get_value(getattr(self, prop_name))
|
||||||
for prop_name in self.base_vars
|
for prop_name in self.base_vars
|
||||||
}
|
}
|
||||||
computed_vars = (
|
if initial:
|
||||||
{
|
computed_vars = {
|
||||||
|
# Include initial computed vars.
|
||||||
|
prop_name: cv._initial_value
|
||||||
|
if isinstance(cv, ComputedVar)
|
||||||
|
and not isinstance(cv._initial_value, types.Unset)
|
||||||
|
else self.get_value(getattr(self, prop_name))
|
||||||
|
for prop_name, cv in self.computed_vars.items()
|
||||||
|
}
|
||||||
|
elif include_computed:
|
||||||
|
computed_vars = {
|
||||||
# Include the computed vars.
|
# Include the computed vars.
|
||||||
prop_name: self.get_value(getattr(self, prop_name))
|
prop_name: self.get_value(getattr(self, prop_name))
|
||||||
for prop_name in self.computed_vars
|
for prop_name in self.computed_vars
|
||||||
}
|
}
|
||||||
if include_computed
|
else:
|
||||||
else {}
|
computed_vars = {}
|
||||||
)
|
|
||||||
variables = {**base_vars, **computed_vars}
|
variables = {**base_vars, **computed_vars}
|
||||||
d = {
|
d = {
|
||||||
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
||||||
}
|
}
|
||||||
for substate_d in [
|
for substate_d in [
|
||||||
v.dict(include_computed=include_computed, **kwargs)
|
v.dict(include_computed=include_computed, initial=initial, **kwargs)
|
||||||
for v in self.substates.values()
|
for v in self.substates.values()
|
||||||
]:
|
]:
|
||||||
d.update(substate_d)
|
d.update(substate_d)
|
||||||
|
@ -43,6 +43,29 @@ StateIterVar = Union[list, set, tuple]
|
|||||||
ArgsSpec = Callable
|
ArgsSpec = Callable
|
||||||
|
|
||||||
|
|
||||||
|
class Unset:
|
||||||
|
"""A class to represent an unset value.
|
||||||
|
|
||||||
|
This is used to differentiate between a value that is not set and a value that is set to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return the string representation of the class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string representation of the class.
|
||||||
|
"""
|
||||||
|
return "Unset"
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
"""Return False when the class is used in a boolean context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
False
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_generic_alias(cls: GenericType) -> bool:
|
def is_generic_alias(cls: GenericType) -> bool:
|
||||||
"""Check whether the class is a generic alias.
|
"""Check whether the class is a generic alias.
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import dis
|
import dis
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@ -1802,24 +1803,26 @@ class ComputedVar(Var, property):
|
|||||||
# Whether to track dependencies and cache computed values
|
# Whether to track dependencies and cache computed values
|
||||||
_cache: bool = dataclasses.field(default=False)
|
_cache: bool = dataclasses.field(default=False)
|
||||||
|
|
||||||
|
_initial_value: Any | types.Unset = dataclasses.field(default_factory=types.Unset)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fget: Callable[[BaseState], Any],
|
fget: Callable[[BaseState], Any],
|
||||||
fset: Callable[[BaseState, Any], None] | None = None,
|
initial_value: Any | types.Unset = types.Unset(),
|
||||||
fdel: Callable[[BaseState], Any] | None = None,
|
cache: bool = False,
|
||||||
doc: str | None = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Initialize a ComputedVar.
|
"""Initialize a ComputedVar.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fget: The getter function.
|
fget: The getter function.
|
||||||
fset: The setter function.
|
initial_value: The initial value of the computed var.
|
||||||
fdel: The deleter function.
|
cache: Whether to cache the computed value.
|
||||||
doc: The docstring.
|
|
||||||
**kwargs: additional attributes to set on the instance
|
**kwargs: additional attributes to set on the instance
|
||||||
"""
|
"""
|
||||||
property.__init__(self, fget, fset, fdel, doc)
|
self._initial_value = initial_value
|
||||||
|
self._cache = cache
|
||||||
|
property.__init__(self, fget)
|
||||||
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
|
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
|
||||||
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
|
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
|
||||||
BaseVar.__init__(self, **kwargs) # type: ignore
|
BaseVar.__init__(self, **kwargs) # type: ignore
|
||||||
@ -1960,21 +1963,39 @@ class ComputedVar(Var, property):
|
|||||||
return Any
|
return Any
|
||||||
|
|
||||||
|
|
||||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
|
def computed_var(
|
||||||
"""A field with computed getter that tracks other state dependencies.
|
fget: Callable[[BaseState], Any] | None = None,
|
||||||
|
initial_value: Any | None = None,
|
||||||
The cached_var will only be recalculated when other state vars that it
|
cache: bool = False,
|
||||||
depends on are modified.
|
**kwargs,
|
||||||
|
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
|
||||||
|
"""A ComputedVar decorator with or without kwargs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fget: the function that calculates the variable value.
|
fget: The getter function.
|
||||||
|
initial_value: The initial value of the computed var.
|
||||||
|
cache: Whether to cache the computed value.
|
||||||
|
**kwargs: additional attributes to set on the instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ComputedVar that is recomputed when dependencies change.
|
A ComputedVar instance.
|
||||||
"""
|
"""
|
||||||
cvar = ComputedVar(fget=fget)
|
if fget is not None:
|
||||||
cvar._cache = True
|
return ComputedVar(fget=fget, cache=cache)
|
||||||
return cvar
|
|
||||||
|
def wrapper(fget):
|
||||||
|
return ComputedVar(
|
||||||
|
fget=fget,
|
||||||
|
initial_value=initial_value,
|
||||||
|
cache=cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
# Partial function of computed_var with cache=True
|
||||||
|
cached_var = functools.partial(computed_var, cache=True)
|
||||||
|
|
||||||
|
|
||||||
class CallableVar(BaseVar):
|
class CallableVar(BaseVar):
|
||||||
|
@ -144,14 +144,19 @@ class ComputedVar(Var):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fget: Callable[[BaseState], Any],
|
fget: Callable[[BaseState], Any],
|
||||||
fset: Callable[[BaseState, Any], None] | None = None,
|
|
||||||
fdel: Callable[[BaseState], Any] | None = None,
|
|
||||||
doc: str | None = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
@overload
|
@overload
|
||||||
def __init__(self, func) -> None: ...
|
def __init__(self, func) -> None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def computed_var(
|
||||||
|
fget: Callable[[BaseState], Any] | None = None,
|
||||||
|
initial_value: Any | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
|
||||||
|
@overload
|
||||||
|
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||||
|
|
||||||
class CallableVar(BaseVar):
|
class CallableVar(BaseVar):
|
||||||
|
@ -9,8 +9,8 @@ from reflex.base import Base
|
|||||||
from reflex.state import BaseState
|
from reflex.state import BaseState
|
||||||
from reflex.vars import (
|
from reflex.vars import (
|
||||||
BaseVar,
|
BaseVar,
|
||||||
ComputedVar,
|
|
||||||
Var,
|
Var,
|
||||||
|
computed_var,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_vars = [
|
test_vars = [
|
||||||
@ -46,7 +46,7 @@ def ParentState(TestObj):
|
|||||||
foo: int
|
foo: int
|
||||||
bar: int
|
bar: int
|
||||||
|
|
||||||
@ComputedVar
|
@computed_var
|
||||||
def var_without_annotation(self):
|
def var_without_annotation(self):
|
||||||
return TestObj
|
return TestObj
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ def ParentState(TestObj):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ChildState(ParentState, TestObj):
|
def ChildState(ParentState, TestObj):
|
||||||
class ChildState(ParentState):
|
class ChildState(ParentState):
|
||||||
@ComputedVar
|
@computed_var
|
||||||
def var_without_annotation(self):
|
def var_without_annotation(self):
|
||||||
return TestObj
|
return TestObj
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ def ChildState(ParentState, TestObj):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def GrandChildState(ChildState, TestObj):
|
def GrandChildState(ChildState, TestObj):
|
||||||
class GrandChildState(ChildState):
|
class GrandChildState(ChildState):
|
||||||
@ComputedVar
|
@computed_var
|
||||||
def var_without_annotation(self):
|
def var_without_annotation(self):
|
||||||
return TestObj
|
return TestObj
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ def GrandChildState(ChildState, TestObj):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def StateWithAnyVar(TestObj):
|
def StateWithAnyVar(TestObj):
|
||||||
class StateWithAnyVar(BaseState):
|
class StateWithAnyVar(BaseState):
|
||||||
@ComputedVar
|
@computed_var
|
||||||
def var_without_annotation(self) -> typing.Any:
|
def var_without_annotation(self) -> typing.Any:
|
||||||
return TestObj
|
return TestObj
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ def StateWithAnyVar(TestObj):
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def StateWithCorrectVarAnnotation():
|
def StateWithCorrectVarAnnotation():
|
||||||
class StateWithCorrectVarAnnotation(BaseState):
|
class StateWithCorrectVarAnnotation(BaseState):
|
||||||
@ComputedVar
|
@computed_var
|
||||||
def var_with_annotation(self) -> str:
|
def var_with_annotation(self) -> str:
|
||||||
return "Correct annotation"
|
return "Correct annotation"
|
||||||
|
|
||||||
@ -96,13 +96,53 @@ def StateWithCorrectVarAnnotation():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def StateWithWrongVarAnnotation(TestObj):
|
def StateWithWrongVarAnnotation(TestObj):
|
||||||
class StateWithWrongVarAnnotation(BaseState):
|
class StateWithWrongVarAnnotation(BaseState):
|
||||||
@ComputedVar
|
@computed_var
|
||||||
def var_with_annotation(self) -> str:
|
def var_with_annotation(self) -> str:
|
||||||
return TestObj
|
return TestObj
|
||||||
|
|
||||||
return StateWithWrongVarAnnotation
|
return StateWithWrongVarAnnotation
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def StateWithInitialComputedVar():
|
||||||
|
class StateWithInitialComputedVar(BaseState):
|
||||||
|
@computed_var(initial_value="Initial value")
|
||||||
|
def var_with_initial_value(self) -> str:
|
||||||
|
return "Runtime value"
|
||||||
|
|
||||||
|
return StateWithInitialComputedVar
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ChildWithInitialComputedVar(StateWithInitialComputedVar):
|
||||||
|
class ChildWithInitialComputedVar(StateWithInitialComputedVar):
|
||||||
|
@computed_var(initial_value="Initial value")
|
||||||
|
def var_with_initial_value_child(self) -> str:
|
||||||
|
return "Runtime value"
|
||||||
|
|
||||||
|
return ChildWithInitialComputedVar
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def StateWithRuntimeOnlyVar():
|
||||||
|
class StateWithRuntimeOnlyVar(BaseState):
|
||||||
|
@computed_var(initial_value=None)
|
||||||
|
def var_raises_at_runtime(self) -> str:
|
||||||
|
raise ValueError("So nicht, mein Freund")
|
||||||
|
|
||||||
|
return StateWithRuntimeOnlyVar
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
|
||||||
|
class ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
|
||||||
|
@computed_var(initial_value="Initial value")
|
||||||
|
def var_raises_at_runtime_child(self) -> str:
|
||||||
|
raise ValueError("So nicht, mein Freund")
|
||||||
|
|
||||||
|
return ChildWithRuntimeOnlyVar
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"prop,expected",
|
"prop,expected",
|
||||||
zip(
|
zip(
|
||||||
@ -731,6 +771,65 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"fixture,var_name,expected_initial,expected_runtime,raises_at_runtime",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"StateWithInitialComputedVar",
|
||||||
|
"var_with_initial_value",
|
||||||
|
"Initial value",
|
||||||
|
"Runtime value",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ChildWithInitialComputedVar",
|
||||||
|
"var_with_initial_value_child",
|
||||||
|
"Initial value",
|
||||||
|
"Runtime value",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"StateWithRuntimeOnlyVar",
|
||||||
|
"var_raises_at_runtime",
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ChildWithRuntimeOnlyVar",
|
||||||
|
"var_raises_at_runtime_child",
|
||||||
|
"Initial value",
|
||||||
|
None,
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_state_with_initial_computed_var(
|
||||||
|
request, fixture, var_name, expected_initial, expected_runtime, raises_at_runtime
|
||||||
|
):
|
||||||
|
"""Test that the initial and runtime values of a computed var are correct.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Fixture Request.
|
||||||
|
fixture: The state fixture.
|
||||||
|
var_name: The name of the computed var.
|
||||||
|
expected_initial: The expected initial value of the computed var.
|
||||||
|
expected_runtime: The expected runtime value of the computed var.
|
||||||
|
raises_at_runtime: Whether the computed var is runtime only.
|
||||||
|
"""
|
||||||
|
state = request.getfixturevalue(fixture)()
|
||||||
|
state_name = state.get_full_name()
|
||||||
|
initial_dict = state.dict(initial=True)[state_name]
|
||||||
|
assert initial_dict[var_name] == expected_initial
|
||||||
|
|
||||||
|
if raises_at_runtime:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
state.dict()[state_name][var_name]
|
||||||
|
else:
|
||||||
|
runtime_dict = state.dict()[state_name]
|
||||||
|
assert runtime_dict[var_name] == expected_runtime
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"out, expected",
|
"out, expected",
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user