DRAFT PR - Added code for computed backend vars (#2540)
* added code for computed backend vars * fixed formatting issues * fix small bug * fixes ruff issue * fixed black issue * augment test for backend computed var --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
cc678e8648
commit
0a18eaa28b
@ -294,7 +294,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self._init_event_handlers()
|
||||
|
||||
# Create a fresh copy of the backend variables for this instance
|
||||
self._backend_vars = copy.deepcopy(self.backend_vars)
|
||||
self._backend_vars = copy.deepcopy(
|
||||
{
|
||||
name: item
|
||||
for name, item in self.backend_vars.items()
|
||||
if name not in self.computed_vars
|
||||
}
|
||||
)
|
||||
|
||||
def _init_event_handlers(self, state: BaseState | None = None):
|
||||
"""Initialize event handlers.
|
||||
@ -330,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
return f"{self.__class__.__name__}({self.dict()})"
|
||||
|
||||
@classmethod
|
||||
def _get_computed_vars(cls) -> list[ComputedVar]:
|
||||
"""Helper function to get all computed vars of a instance.
|
||||
|
||||
Returns:
|
||||
A list of computed vars.
|
||||
"""
|
||||
return [
|
||||
v
|
||||
for mixin in cls.__mro__
|
||||
if mixin is cls or not issubclass(mixin, (BaseState, ABC))
|
||||
for v in mixin.__dict__.values()
|
||||
if isinstance(v, ComputedVar)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Do some magic for the subclass initialization.
|
||||
@ -376,6 +397,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Track this new subclass in the parent state's subclasses set.
|
||||
parent_state.class_subclasses.add(cls)
|
||||
|
||||
# Get computed vars.
|
||||
computed_vars = cls._get_computed_vars()
|
||||
|
||||
new_backend_vars = {
|
||||
name: value
|
||||
for name, value in cls.__dict__.items()
|
||||
@ -383,9 +407,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
and name not in RESERVED_BACKEND_VAR_NAMES
|
||||
and name not in cls.inherited_backend_vars
|
||||
and not isinstance(value, FunctionType)
|
||||
and not isinstance(value, ComputedVar)
|
||||
}
|
||||
|
||||
cls.backend_vars = {**cls.inherited_backend_vars, **new_backend_vars}
|
||||
# Get backend computed vars
|
||||
backend_computed_vars = {
|
||||
v._var_name: v._var_set_state(cls)
|
||||
for v in computed_vars
|
||||
if types.is_backend_variable(v._var_name, cls)
|
||||
and v._var_name not in cls.inherited_backend_vars
|
||||
}
|
||||
|
||||
cls.backend_vars = {
|
||||
**cls.inherited_backend_vars,
|
||||
**new_backend_vars,
|
||||
**backend_computed_vars,
|
||||
}
|
||||
|
||||
# Set the base and computed vars.
|
||||
cls.base_vars = {
|
||||
@ -395,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for f in cls.get_fields().values()
|
||||
if f.name not in cls.get_skip_vars()
|
||||
}
|
||||
cls.computed_vars = {
|
||||
v._var_name: v._var_set_state(cls)
|
||||
for v in cls.__dict__.values()
|
||||
if isinstance(v, ComputedVar)
|
||||
}
|
||||
cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars}
|
||||
cls.vars = {
|
||||
**cls.inherited_vars,
|
||||
**cls.base_vars,
|
||||
|
@ -955,6 +955,24 @@ class InterdependentState(BaseState):
|
||||
"""
|
||||
return self.v1x2 * 2 # type: ignore
|
||||
|
||||
@rx.cached_var
|
||||
def _v3(self) -> int:
|
||||
"""Depends on backend var _v2.
|
||||
|
||||
Returns:
|
||||
The value of the backend variable.
|
||||
"""
|
||||
return self._v2
|
||||
|
||||
@rx.cached_var
|
||||
def v3x2(self) -> int:
|
||||
"""Depends on ComputedVar _v3.
|
||||
|
||||
Returns:
|
||||
ComputedVar _v3 multiplied by 2
|
||||
"""
|
||||
return self._v3 * 2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def interdependent_state() -> BaseState:
|
||||
@ -1003,8 +1021,9 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||
"""
|
||||
interdependent_state._v2 = 2
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
|
||||
}
|
||||
assert "_v3" in InterdependentState.backend_vars
|
||||
|
||||
|
||||
def test_per_state_backend_var(interdependent_state):
|
||||
|
Loading…
Reference in New Issue
Block a user