DRAFT PR - Added code for computed backend vars (#2540)

* added code for computed backend vars

* fixed formatting issues

* fix small bug

* fixes ruff issue

* fixed black issue

* augment test for backend computed var

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
wassaf shahzad 2024-02-29 22:00:41 +01:00 committed by GitHub
parent cc678e8648
commit 0a18eaa28b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 8 deletions

View File

@ -294,7 +294,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self._init_event_handlers()
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars)
self._backend_vars = copy.deepcopy(
{
name: item
for name, item in self.backend_vars.items()
if name not in self.computed_vars
}
)
def _init_event_handlers(self, state: BaseState | None = None):
"""Initialize event handlers.
@ -330,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
return f"{self.__class__.__name__}({self.dict()})"
@classmethod
def _get_computed_vars(cls) -> list[ComputedVar]:
"""Helper function to get all computed vars of a instance.
Returns:
A list of computed vars.
"""
return [
v
for mixin in cls.__mro__
if mixin is cls or not issubclass(mixin, (BaseState, ABC))
for v in mixin.__dict__.values()
if isinstance(v, ComputedVar)
]
@classmethod
def __init_subclass__(cls, **kwargs):
"""Do some magic for the subclass initialization.
@ -376,6 +397,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Track this new subclass in the parent state's subclasses set.
parent_state.class_subclasses.add(cls)
# Get computed vars.
computed_vars = cls._get_computed_vars()
new_backend_vars = {
name: value
for name, value in cls.__dict__.items()
@ -383,9 +407,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
and name not in RESERVED_BACKEND_VAR_NAMES
and name not in cls.inherited_backend_vars
and not isinstance(value, FunctionType)
and not isinstance(value, ComputedVar)
}
cls.backend_vars = {**cls.inherited_backend_vars, **new_backend_vars}
# Get backend computed vars
backend_computed_vars = {
v._var_name: v._var_set_state(cls)
for v in computed_vars
if types.is_backend_variable(v._var_name, cls)
and v._var_name not in cls.inherited_backend_vars
}
cls.backend_vars = {
**cls.inherited_backend_vars,
**new_backend_vars,
**backend_computed_vars,
}
# Set the base and computed vars.
cls.base_vars = {
@ -395,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for f in cls.get_fields().values()
if f.name not in cls.get_skip_vars()
}
cls.computed_vars = {
v._var_name: v._var_set_state(cls)
for v in cls.__dict__.values()
if isinstance(v, ComputedVar)
}
cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars}
cls.vars = {
**cls.inherited_vars,
**cls.base_vars,

View File

@ -955,6 +955,24 @@ class InterdependentState(BaseState):
"""
return self.v1x2 * 2 # type: ignore
@rx.cached_var
def _v3(self) -> int:
"""Depends on backend var _v2.
Returns:
The value of the backend variable.
"""
return self._v2
@rx.cached_var
def v3x2(self) -> int:
"""Depends on ComputedVar _v3.
Returns:
ComputedVar _v3 multiplied by 2
"""
return self._v3 * 2
@pytest.fixture
def interdependent_state() -> BaseState:
@ -1003,8 +1021,9 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
"""
interdependent_state._v2 = 2
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4},
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
}
assert "_v3" in InterdependentState.backend_vars
def test_per_state_backend_var(interdependent_state):