properly replace ComputedVars (#3254)

* properly replace ComputedVars

* provide dummy override decorator for older python versions

* adjust pyi

* fix darglint

* cleanup var_data for computed vars inherited from mixin

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
benedikt-bartscher 2024-05-16 00:01:14 +02:00 committed by GitHub
parent c5f32db756
commit 48c504666e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 54 additions and 7 deletions

View File

@ -550,7 +550,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for name, value in mixin.__dict__.items():
if isinstance(value, ComputedVar):
fget = cls._copy_fn(value.fget)
newcv = ComputedVar(fget=fget, _var_name=value._var_name)
newcv = value._replace(fget=fget)
# cleanup refs to mixin cls in var_data
newcv._var_data = None
newcv._var_set_state(cls)
setattr(cls, name, newcv)
cls.computed_vars[newcv._var_name] = newcv

View File

@ -44,6 +44,22 @@ from reflex import constants
from reflex.base import Base
from reflex.utils import console, serializers
if sys.version_info >= (3, 12):
from typing import override
else:
def override(func: Callable) -> Callable:
"""Fallback for @override decorator.
Args:
func: The function to decorate.
Returns:
The unmodified function.
"""
return func
# Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias]

View File

@ -39,10 +39,12 @@ from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueErr
# This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.types import override
if TYPE_CHECKING:
from reflex.state import BaseState
# Set of unique variable names.
USED_VARIABLES = set()
@ -832,19 +834,19 @@ class Var:
if invoke_fn:
# invoke the function on left operand.
operation_name = (
f"{left_operand_full_name}.{fn}({right_operand_full_name})"
) # type: ignore
f"{left_operand_full_name}.{fn}({right_operand_full_name})" # type: ignore
)
else:
# pass the operands as arguments to the function.
operation_name = (
f"{left_operand_full_name} {op} {right_operand_full_name}"
) # type: ignore
f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore
)
operation_name = f"{fn}({operation_name})"
else:
# apply operator to operands (left operand <operator> right_operand)
operation_name = (
f"{left_operand_full_name} {op} {right_operand_full_name}"
) # type: ignore
f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore
)
operation_name = format.wrap(operation_name, "(")
else:
# apply operator to left operand (<operator> left_operand)
@ -1882,6 +1884,32 @@ class ComputedVar(Var, property):
kwargs["_var_type"] = kwargs.pop("_var_type", self._determine_var_type())
BaseVar.__init__(self, **kwargs) # type: ignore
@override
def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar:
"""Replace the attributes of the ComputedVar.
Args:
merge_var_data: VarData to merge into the existing VarData.
**kwargs: Var fields to update.
Returns:
The new ComputedVar instance.
"""
return ComputedVar(
fget=kwargs.get("fget", self.fget),
initial_value=kwargs.get("initial_value", self._initial_value),
cache=kwargs.get("cache", self._cache),
_var_name=kwargs.get("_var_name", self._var_name),
_var_type=kwargs.get("_var_type", self._var_type),
_var_is_local=kwargs.get("_var_is_local", self._var_is_local),
_var_is_string=kwargs.get("_var_is_string", self._var_is_string),
_var_full_name_needs_state_prefix=kwargs.get(
"_var_full_name_needs_state_prefix",
self._var_full_name_needs_state_prefix,
),
_var_data=VarData.merge(self._var_data, merge_var_data),
)
@property
def _cache_attr(self) -> str:
"""Get the attribute used to cache the value on the instance.

View File

@ -139,6 +139,7 @@ class ComputedVar(Var):
def _cache_attr(self) -> str: ...
def __get__(self, instance, owner): ...
def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar: ...
def mark_dirty(self, instance) -> None: ...
def _determine_var_type(self) -> Type: ...
@overload