diff --git a/reflex/state.py b/reflex/state.py index b2f9f5e01..835f2782f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -370,6 +370,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A special event handler for setting base vars. setvar: ClassVar[EventHandler] + # Track if computed vars have changed since last serialization + _changed_computed_vars: Set[str] = set() + def __init__( self, parent_state: BaseState | None = None, @@ -1825,14 +1828,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): calc_vars, dirty_vars = dirty_vars, set() for cvar in self._dirty_computed_vars(from_vars=calc_vars): actual_var = self.computed_vars.get(cvar) - if actual_var is not None: - if actual_var.has_changed(instance=self): - actual_var.mark_dirty(instance=self) - else: - # var has not changed, do not mark as dirty - continue - self.dirty_vars.add(cvar) - dirty_vars.add(cvar) + assert actual_var is not None + if actual_var.has_changed(instance=self): + actual_var.mark_dirty(instance=self) + self.dirty_vars.add(cvar) + dirty_vars.add(cvar) def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. @@ -2112,6 +2112,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"]["parent_state"] = None state["__dict__"]["substates"] = {} state["__dict__"].pop("_was_touched", None) + state["__dict__"].pop("_changed_computed_vars", None) # Remove all inherited vars. for inherited_var_name in self.inherited_vars: state["__dict__"].pop(inherited_var_name, None) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 0c39eacc4..28499c378 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -122,6 +122,7 @@ RESERVED_BACKEND_VAR_NAMES = { "_abc_impl", "_backend_vars", "_was_touched", + "_changed_computed_vars", } if sys.version_info >= (3, 11): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 1c14dc9e1..2891dded4 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2027,8 +2027,16 @@ class ComputedVar(Var[RETURN_TYPE]): else: # handle caching if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # Get the new value. + new_value = self.fget(instance) + # Get the current cached value. + cached_value = getattr(instance, self._cache_attr, None) + # Check if the new value is different from the cached value. + if new_value == cached_value: + return new_value + instance._changed_computed_vars.add(self._js_expr) # Set cache attr on state instance. - setattr(instance, self._cache_attr, self.fget(instance)) + setattr(instance, self._cache_attr, new_value) # Ensure the computed var gets serialized to redis. instance._was_touched = True # Set the last updated timestamp on the state instance. @@ -2176,9 +2184,19 @@ class ComputedVar(Var[RETURN_TYPE]): Returns: True if the value has changed, False otherwise. """ - cached_value = getattr(instance, self._cache_attr, None) + if not self._cache: + return True + if self._js_expr in instance._changed_computed_vars: + return True + if not hasattr(instance, self._cache_attr): + return True + cached_value = getattr(instance, self._cache_attr) new_value = self.fget(instance) - return cached_value != new_value + has_changed = cached_value != new_value + if has_changed: + instance._changed_computed_vars.add(self._js_expr) + setattr(instance, self._cache_attr, new_value) + return has_changed def _determine_var_type(self) -> Type: """Get the type of the var.