only implement changed check for ComputedVar, put it in has_changed method

This commit is contained in:
Benedikt Bartscher 2024-12-05 19:52:17 +01:00
parent 421ed98ebd
commit 61e503c4d1
No known key found for this signature in database
2 changed files with 16 additions and 14 deletions

View File

@ -1288,8 +1288,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return return
if name in self.backend_vars: if name in self.backend_vars:
if self._backend_vars.get(name) == value:
return
self._backend_vars.__setitem__(name, value) self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name) self.dirty_vars.add(name)
self._mark_dirty() self._mark_dirty()
@ -1325,9 +1323,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
) )
# Set the attribute. # Set the attribute.
current_value = getattr(self, name, None)
if current_value == value:
return
super().__setattr__(name, value) super().__setattr__(name, value)
# Add the var to the dirty list. # Add the var to the dirty list.
@ -1831,8 +1826,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for cvar in self._dirty_computed_vars(from_vars=calc_vars): for cvar in self._dirty_computed_vars(from_vars=calc_vars):
actual_var = self.computed_vars.get(cvar) actual_var = self.computed_vars.get(cvar)
if actual_var is not None: if actual_var is not None:
changed = actual_var.mark_dirty(instance=self) if actual_var.has_changed(instance=self):
if not changed: actual_var.mark_dirty(instance=self)
else:
# var has not changed, do not mark as dirty
continue continue
self.dirty_vars.add(cvar) self.dirty_vars.add(cvar)
dirty_vars.add(cvar) dirty_vars.add(cvar)

View File

@ -2158,22 +2158,27 @@ class ComputedVar(Var[RETURN_TYPE]):
self_is_top_of_stack = False self_is_top_of_stack = False
return d return d
def mark_dirty(self, instance: BaseState) -> bool: def mark_dirty(self, instance: BaseState) -> None:
"""Mark this ComputedVar as dirty. """Mark this ComputedVar as dirty.
Args:
instance: the state instance that needs to recompute the value.
"""
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
def has_changed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar value has changed.
Args: Args:
instance: the state instance that needs to recompute the value. instance: the state instance that needs to recompute the value.
Returns: Returns:
True if the value was marked dirty (has changed), False otherwise. True if the value has changed, False otherwise.
""" """
cached_value = getattr(instance, self._cache_attr, None) cached_value = getattr(instance, self._cache_attr, None)
new_value = self.fget(instance) new_value = self.fget(instance)
if cached_value == new_value: return cached_value != new_value
return False
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
return True
def _determine_var_type(self) -> Type: def _determine_var_type(self) -> Type:
"""Get the type of the var. """Get the type of the var.