From 48c504666eef1b954f92c110eb239befdad90aae Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Thu, 16 May 2024 00:01:14 +0200 Subject: [PATCH] properly replace ComputedVars (#3254) * properly replace ComputedVars * provide dummy override decorator for older python versions * adjust pyi * fix darglint * cleanup var_data for computed vars inherited from mixin --------- Co-authored-by: Masen Furer --- reflex/state.py | 4 +++- reflex/utils/types.py | 16 ++++++++++++++++ reflex/vars.py | 40 ++++++++++++++++++++++++++++++++++------ reflex/vars.pyi | 1 + 4 files changed, 54 insertions(+), 7 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 287f70073..86a222b66 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -550,7 +550,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for name, value in mixin.__dict__.items(): if isinstance(value, ComputedVar): fget = cls._copy_fn(value.fget) - newcv = ComputedVar(fget=fget, _var_name=value._var_name) + newcv = value._replace(fget=fget) + # cleanup refs to mixin cls in var_data + newcv._var_data = None newcv._var_set_state(cls) setattr(cls, name, newcv) cls.computed_vars[newcv._var_name] = newcv diff --git a/reflex/utils/types.py b/reflex/utils/types.py index f75e20dcc..6dd120e3d 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -44,6 +44,22 @@ from reflex import constants from reflex.base import Base from reflex.utils import console, serializers +if sys.version_info >= (3, 12): + from typing import override +else: + + def override(func: Callable) -> Callable: + """Fallback for @override decorator. + + Args: + func: The function to decorate. + + Returns: + The unmodified function. + """ + return func + + # Potential GenericAlias types for isinstance checks. GenericAliasTypes = [_GenericAlias] diff --git a/reflex/vars.py b/reflex/vars.py index cccef95eb..be6aa7eb8 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -39,10 +39,12 @@ from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueErr # This module used to export ImportVar itself, so we still import it for export here from reflex.utils.imports import ImportDict, ImportVar +from reflex.utils.types import override if TYPE_CHECKING: from reflex.state import BaseState + # Set of unique variable names. USED_VARIABLES = set() @@ -832,19 +834,19 @@ class Var: if invoke_fn: # invoke the function on left operand. operation_name = ( - f"{left_operand_full_name}.{fn}({right_operand_full_name})" - ) # type: ignore + f"{left_operand_full_name}.{fn}({right_operand_full_name})" # type: ignore + ) else: # pass the operands as arguments to the function. operation_name = ( - f"{left_operand_full_name} {op} {right_operand_full_name}" - ) # type: ignore + f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore + ) operation_name = f"{fn}({operation_name})" else: # apply operator to operands (left operand right_operand) operation_name = ( - f"{left_operand_full_name} {op} {right_operand_full_name}" - ) # type: ignore + f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore + ) operation_name = format.wrap(operation_name, "(") else: # apply operator to left operand ( left_operand) @@ -1882,6 +1884,32 @@ class ComputedVar(Var, property): kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type()) BaseVar.__init__(self, **kwargs) # type: ignore + @override + def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: + """Replace the attributes of the ComputedVar. + + Args: + merge_var_data: VarData to merge into the existing VarData. + **kwargs: Var fields to update. + + Returns: + The new ComputedVar instance. + """ + return ComputedVar( + fget=kwargs.get("fget", self.fget), + initial_value=kwargs.get("initial_value", self._initial_value), + cache=kwargs.get("cache", self._cache), + _var_name=kwargs.get("_var_name", self._var_name), + _var_type=kwargs.get("_var_type", self._var_type), + _var_is_local=kwargs.get("_var_is_local", self._var_is_local), + _var_is_string=kwargs.get("_var_is_string", self._var_is_string), + _var_full_name_needs_state_prefix=kwargs.get( + "_var_full_name_needs_state_prefix", + self._var_full_name_needs_state_prefix, + ), + _var_data=VarData.merge(self._var_data, merge_var_data), + ) + @property def _cache_attr(self) -> str: """Get the attribute used to cache the value on the instance. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 8251563f8..169e2d919 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -139,6 +139,7 @@ class ComputedVar(Var): def _cache_attr(self) -> str: ... def __get__(self, instance, owner): ... def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ... + def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ... def mark_dirty(self, instance) -> None: ... def _determine_var_type(self) -> Type: ... @overload