DRAFT PR - Added code for computed backend vars (#2540)
* added code for computed backend vars * fixed formatting issues * fix small bug * fixes ruff issue * fixed black issue * augment test for backend computed var --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
cc678e8648
commit
0a18eaa28b
@ -294,7 +294,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
self._init_event_handlers()
|
self._init_event_handlers()
|
||||||
|
|
||||||
# Create a fresh copy of the backend variables for this instance
|
# Create a fresh copy of the backend variables for this instance
|
||||||
self._backend_vars = copy.deepcopy(self.backend_vars)
|
self._backend_vars = copy.deepcopy(
|
||||||
|
{
|
||||||
|
name: item
|
||||||
|
for name, item in self.backend_vars.items()
|
||||||
|
if name not in self.computed_vars
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def _init_event_handlers(self, state: BaseState | None = None):
|
def _init_event_handlers(self, state: BaseState | None = None):
|
||||||
"""Initialize event handlers.
|
"""Initialize event handlers.
|
||||||
@ -330,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"""
|
"""
|
||||||
return f"{self.__class__.__name__}({self.dict()})"
|
return f"{self.__class__.__name__}({self.dict()})"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_computed_vars(cls) -> list[ComputedVar]:
|
||||||
|
"""Helper function to get all computed vars of a instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of computed vars.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
v
|
||||||
|
for mixin in cls.__mro__
|
||||||
|
if mixin is cls or not issubclass(mixin, (BaseState, ABC))
|
||||||
|
for v in mixin.__dict__.values()
|
||||||
|
if isinstance(v, ComputedVar)
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
"""Do some magic for the subclass initialization.
|
"""Do some magic for the subclass initialization.
|
||||||
@ -376,6 +397,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Track this new subclass in the parent state's subclasses set.
|
# Track this new subclass in the parent state's subclasses set.
|
||||||
parent_state.class_subclasses.add(cls)
|
parent_state.class_subclasses.add(cls)
|
||||||
|
|
||||||
|
# Get computed vars.
|
||||||
|
computed_vars = cls._get_computed_vars()
|
||||||
|
|
||||||
new_backend_vars = {
|
new_backend_vars = {
|
||||||
name: value
|
name: value
|
||||||
for name, value in cls.__dict__.items()
|
for name, value in cls.__dict__.items()
|
||||||
@ -383,9 +407,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
and name not in RESERVED_BACKEND_VAR_NAMES
|
and name not in RESERVED_BACKEND_VAR_NAMES
|
||||||
and name not in cls.inherited_backend_vars
|
and name not in cls.inherited_backend_vars
|
||||||
and not isinstance(value, FunctionType)
|
and not isinstance(value, FunctionType)
|
||||||
|
and not isinstance(value, ComputedVar)
|
||||||
}
|
}
|
||||||
|
|
||||||
cls.backend_vars = {**cls.inherited_backend_vars, **new_backend_vars}
|
# Get backend computed vars
|
||||||
|
backend_computed_vars = {
|
||||||
|
v._var_name: v._var_set_state(cls)
|
||||||
|
for v in computed_vars
|
||||||
|
if types.is_backend_variable(v._var_name, cls)
|
||||||
|
and v._var_name not in cls.inherited_backend_vars
|
||||||
|
}
|
||||||
|
|
||||||
|
cls.backend_vars = {
|
||||||
|
**cls.inherited_backend_vars,
|
||||||
|
**new_backend_vars,
|
||||||
|
**backend_computed_vars,
|
||||||
|
}
|
||||||
|
|
||||||
# Set the base and computed vars.
|
# Set the base and computed vars.
|
||||||
cls.base_vars = {
|
cls.base_vars = {
|
||||||
@ -395,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
for f in cls.get_fields().values()
|
for f in cls.get_fields().values()
|
||||||
if f.name not in cls.get_skip_vars()
|
if f.name not in cls.get_skip_vars()
|
||||||
}
|
}
|
||||||
cls.computed_vars = {
|
cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars}
|
||||||
v._var_name: v._var_set_state(cls)
|
|
||||||
for v in cls.__dict__.values()
|
|
||||||
if isinstance(v, ComputedVar)
|
|
||||||
}
|
|
||||||
cls.vars = {
|
cls.vars = {
|
||||||
**cls.inherited_vars,
|
**cls.inherited_vars,
|
||||||
**cls.base_vars,
|
**cls.base_vars,
|
||||||
|
@ -955,6 +955,24 @@ class InterdependentState(BaseState):
|
|||||||
"""
|
"""
|
||||||
return self.v1x2 * 2 # type: ignore
|
return self.v1x2 * 2 # type: ignore
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def _v3(self) -> int:
|
||||||
|
"""Depends on backend var _v2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the backend variable.
|
||||||
|
"""
|
||||||
|
return self._v2
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def v3x2(self) -> int:
|
||||||
|
"""Depends on ComputedVar _v3.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ComputedVar _v3 multiplied by 2
|
||||||
|
"""
|
||||||
|
return self._v3 * 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def interdependent_state() -> BaseState:
|
def interdependent_state() -> BaseState:
|
||||||
@ -1003,8 +1021,9 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
|
|||||||
"""
|
"""
|
||||||
interdependent_state._v2 = 2
|
interdependent_state._v2 = 2
|
||||||
assert interdependent_state.get_delta() == {
|
assert interdependent_state.get_delta() == {
|
||||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
|
||||||
}
|
}
|
||||||
|
assert "_v3" in InterdependentState.backend_vars
|
||||||
|
|
||||||
|
|
||||||
def test_per_state_backend_var(interdependent_state):
|
def test_per_state_backend_var(interdependent_state):
|
||||||
|
Loading…
Reference in New Issue
Block a user