From 0a18eaa28b7294c797f955c8a93ea1ad48ef19dd Mon Sep 17 00:00:00 2001 From: wassaf shahzad Date: Thu, 29 Feb 2024 22:00:41 +0100 Subject: [PATCH] DRAFT PR - Added code for computed backend vars (#2540) * added code for computed backend vars * fixed formatting issues * fix small bug * fixes ruff issue * fixed black issue * augment test for backend computed var --------- Co-authored-by: Masen Furer --- reflex/state.py | 47 ++++++++++++++++++++++++++++++++++++++------- tests/test_state.py | 21 +++++++++++++++++++- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 8e1f40fa4..19cb755a0 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -294,7 +294,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self._init_event_handlers() # Create a fresh copy of the backend variables for this instance - self._backend_vars = copy.deepcopy(self.backend_vars) + self._backend_vars = copy.deepcopy( + { + name: item + for name, item in self.backend_vars.items() + if name not in self.computed_vars + } + ) def _init_event_handlers(self, state: BaseState | None = None): """Initialize event handlers. @@ -330,6 +336,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ return f"{self.__class__.__name__}({self.dict()})" + @classmethod + def _get_computed_vars(cls) -> list[ComputedVar]: + """Helper function to get all computed vars of a instance. + + Returns: + A list of computed vars. + """ + return [ + v + for mixin in cls.__mro__ + if mixin is cls or not issubclass(mixin, (BaseState, ABC)) + for v in mixin.__dict__.values() + if isinstance(v, ComputedVar) + ] + @classmethod def __init_subclass__(cls, **kwargs): """Do some magic for the subclass initialization. @@ -376,6 +397,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Track this new subclass in the parent state's subclasses set. parent_state.class_subclasses.add(cls) + # Get computed vars. + computed_vars = cls._get_computed_vars() + new_backend_vars = { name: value for name, value in cls.__dict__.items() @@ -383,9 +407,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): and name not in RESERVED_BACKEND_VAR_NAMES and name not in cls.inherited_backend_vars and not isinstance(value, FunctionType) + and not isinstance(value, ComputedVar) } - cls.backend_vars = {**cls.inherited_backend_vars, **new_backend_vars} + # Get backend computed vars + backend_computed_vars = { + v._var_name: v._var_set_state(cls) + for v in computed_vars + if types.is_backend_variable(v._var_name, cls) + and v._var_name not in cls.inherited_backend_vars + } + + cls.backend_vars = { + **cls.inherited_backend_vars, + **new_backend_vars, + **backend_computed_vars, + } # Set the base and computed vars. cls.base_vars = { @@ -395,11 +432,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for f in cls.get_fields().values() if f.name not in cls.get_skip_vars() } - cls.computed_vars = { - v._var_name: v._var_set_state(cls) - for v in cls.__dict__.values() - if isinstance(v, ComputedVar) - } + cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars} cls.vars = { **cls.inherited_vars, **cls.base_vars, diff --git a/tests/test_state.py b/tests/test_state.py index c1c692a31..3e97c82db 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -955,6 +955,24 @@ class InterdependentState(BaseState): """ return self.v1x2 * 2 # type: ignore + @rx.cached_var + def _v3(self) -> int: + """Depends on backend var _v2. + + Returns: + The value of the backend variable. + """ + return self._v2 + + @rx.cached_var + def v3x2(self) -> int: + """Depends on ComputedVar _v3. + + Returns: + ComputedVar _v3 multiplied by 2 + """ + return self._v3 * 2 + @pytest.fixture def interdependent_state() -> BaseState: @@ -1003,8 +1021,9 @@ def test_dirty_computed_var_from_backend_var(interdependent_state): """ interdependent_state._v2 = 2 assert interdependent_state.get_delta() == { - interdependent_state.get_full_name(): {"v2x2": 4}, + interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4}, } + assert "_v3" in InterdependentState.backend_vars def test_per_state_backend_var(interdependent_state):