From 02dd4a1313f698784a2284082f8214941162adcd Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Thu, 5 Dec 2024 15:41:26 +0100 Subject: [PATCH 1/9] wip --- reflex/state.py | 6 ++++++ reflex/vars/base.py | 1 + 2 files changed, 7 insertions(+) diff --git a/reflex/state.py b/reflex/state.py index 55f29cf45..118dbc81b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1288,6 +1288,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return if name in self.backend_vars: + if self._backend_vars.get(name) == value: + return + print(f"Setting {name} to {value}.") self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) self._mark_dirty() @@ -1323,6 +1326,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) # Set the attribute. + current_value = getattr(self, name, None) + if current_value == value: + return super().__setattr__(name, value) # Add the var to the dirty list. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 200f693de..b61b987e6 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2164,6 +2164,7 @@ class ComputedVar(Var[RETURN_TYPE]): Args: instance: the state instance that needs to recompute the value. """ + print(f"Marking {self._js_expr} as dirty") with contextlib.suppress(AttributeError): delattr(instance, self._cache_attr) From 421ed98ebd7068ce41a2a78a16a9be84a7d63ae6 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Thu, 5 Dec 2024 16:06:36 +0100 Subject: [PATCH 2/9] detect unchanged computed vars as well --- reflex/state.py | 9 +++++---- reflex/vars/base.py | 11 +++++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 118dbc81b..4c6e8fec6 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1290,7 +1290,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if name in self.backend_vars: if self._backend_vars.get(name) == value: return - print(f"Setting {name} to {value}.") self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) self._mark_dirty() @@ -1830,11 +1829,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): while dirty_vars: calc_vars, dirty_vars = dirty_vars, set() for cvar in self._dirty_computed_vars(from_vars=calc_vars): - self.dirty_vars.add(cvar) - dirty_vars.add(cvar) actual_var = self.computed_vars.get(cvar) if actual_var is not None: - actual_var.mark_dirty(instance=self) + changed = actual_var.mark_dirty(instance=self) + if not changed: + continue + self.dirty_vars.add(cvar) + dirty_vars.add(cvar) def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index b61b987e6..8444a29d0 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2158,15 +2158,22 @@ class ComputedVar(Var[RETURN_TYPE]): self_is_top_of_stack = False return d - def mark_dirty(self, instance) -> None: + def mark_dirty(self, instance: BaseState) -> bool: """Mark this ComputedVar as dirty. Args: instance: the state instance that needs to recompute the value. + + Returns: + True if the value was marked dirty (has changed), False otherwise. """ - print(f"Marking {self._js_expr} as dirty") + cached_value = getattr(instance, self._cache_attr, None) + new_value = self.fget(instance) + if cached_value == new_value: + return False with contextlib.suppress(AttributeError): delattr(instance, self._cache_attr) + return True def _determine_var_type(self) -> Type: """Get the type of the var. From 61e503c4d13f772bb13e42e946df509c2663db9f Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Thu, 5 Dec 2024 19:52:17 +0100 Subject: [PATCH 3/9] only implement changed check for ComputedVar, put it in has_changed method --- reflex/state.py | 11 ++++------- reflex/vars/base.py | 19 ++++++++++++------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 4c6e8fec6..b2f9f5e01 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1288,8 +1288,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return if name in self.backend_vars: - if self._backend_vars.get(name) == value: - return self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) self._mark_dirty() @@ -1325,9 +1323,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) # Set the attribute. - current_value = getattr(self, name, None) - if current_value == value: - return super().__setattr__(name, value) # Add the var to the dirty list. @@ -1831,8 +1826,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for cvar in self._dirty_computed_vars(from_vars=calc_vars): actual_var = self.computed_vars.get(cvar) if actual_var is not None: - changed = actual_var.mark_dirty(instance=self) - if not changed: + if actual_var.has_changed(instance=self): + actual_var.mark_dirty(instance=self) + else: + # var has not changed, do not mark as dirty continue self.dirty_vars.add(cvar) dirty_vars.add(cvar) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 8444a29d0..1c14dc9e1 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2158,22 +2158,27 @@ class ComputedVar(Var[RETURN_TYPE]): self_is_top_of_stack = False return d - def mark_dirty(self, instance: BaseState) -> bool: + def mark_dirty(self, instance: BaseState) -> None: """Mark this ComputedVar as dirty. + Args: + instance: the state instance that needs to recompute the value. + """ + with contextlib.suppress(AttributeError): + delattr(instance, self._cache_attr) + + def has_changed(self, instance: BaseState) -> bool: + """Check if the ComputedVar value has changed. + Args: instance: the state instance that needs to recompute the value. Returns: - True if the value was marked dirty (has changed), False otherwise. + True if the value has changed, False otherwise. """ cached_value = getattr(instance, self._cache_attr, None) new_value = self.fget(instance) - if cached_value == new_value: - return False - with contextlib.suppress(AttributeError): - delattr(instance, self._cache_attr) - return True + return cached_value != new_value def _determine_var_type(self) -> Type: """Get the type of the var. From 8ceccc61405949bd813a458cec447fba5dd9ad03 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Fri, 6 Dec 2024 01:35:23 +0100 Subject: [PATCH 4/9] proper computed var change detection --- reflex/state.py | 17 +++++++++-------- reflex/utils/types.py | 1 + reflex/vars/base.py | 24 +++++++++++++++++++++--- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index b2f9f5e01..835f2782f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -370,6 +370,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A special event handler for setting base vars. setvar: ClassVar[EventHandler] + # Track if computed vars have changed since last serialization + _changed_computed_vars: Set[str] = set() + def __init__( self, parent_state: BaseState | None = None, @@ -1825,14 +1828,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): calc_vars, dirty_vars = dirty_vars, set() for cvar in self._dirty_computed_vars(from_vars=calc_vars): actual_var = self.computed_vars.get(cvar) - if actual_var is not None: - if actual_var.has_changed(instance=self): - actual_var.mark_dirty(instance=self) - else: - # var has not changed, do not mark as dirty - continue - self.dirty_vars.add(cvar) - dirty_vars.add(cvar) + assert actual_var is not None + if actual_var.has_changed(instance=self): + actual_var.mark_dirty(instance=self) + self.dirty_vars.add(cvar) + dirty_vars.add(cvar) def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. @@ -2112,6 +2112,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"]["parent_state"] = None state["__dict__"]["substates"] = {} state["__dict__"].pop("_was_touched", None) + state["__dict__"].pop("_changed_computed_vars", None) # Remove all inherited vars. for inherited_var_name in self.inherited_vars: state["__dict__"].pop(inherited_var_name, None) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 0c39eacc4..28499c378 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -122,6 +122,7 @@ RESERVED_BACKEND_VAR_NAMES = { "_abc_impl", "_backend_vars", "_was_touched", + "_changed_computed_vars", } if sys.version_info >= (3, 11): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 1c14dc9e1..2891dded4 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2027,8 +2027,16 @@ class ComputedVar(Var[RETURN_TYPE]): else: # handle caching if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # Get the new value. + new_value = self.fget(instance) + # Get the current cached value. + cached_value = getattr(instance, self._cache_attr, None) + # Check if the new value is different from the cached value. + if new_value == cached_value: + return new_value + instance._changed_computed_vars.add(self._js_expr) # Set cache attr on state instance. - setattr(instance, self._cache_attr, self.fget(instance)) + setattr(instance, self._cache_attr, new_value) # Ensure the computed var gets serialized to redis. instance._was_touched = True # Set the last updated timestamp on the state instance. @@ -2176,9 +2184,19 @@ class ComputedVar(Var[RETURN_TYPE]): Returns: True if the value has changed, False otherwise. """ - cached_value = getattr(instance, self._cache_attr, None) + if not self._cache: + return True + if self._js_expr in instance._changed_computed_vars: + return True + if not hasattr(instance, self._cache_attr): + return True + cached_value = getattr(instance, self._cache_attr) new_value = self.fget(instance) - return cached_value != new_value + has_changed = cached_value != new_value + if has_changed: + instance._changed_computed_vars.add(self._js_expr) + setattr(instance, self._cache_attr, new_value) + return has_changed def _determine_var_type(self) -> Type: """Get the type of the var. From 1a3a2309290c7773347d7ef6191f4281135c5091 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sun, 8 Dec 2024 22:55:58 +0100 Subject: [PATCH 5/9] track which computed vars have already been computed --- reflex/state.py | 21 ++++++++ reflex/utils/types.py | 2 + reflex/vars/base.py | 105 +++++++++++++++++++++++++++----------- tests/units/test_state.py | 27 ++++++++++ 4 files changed, 124 insertions(+), 31 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 835f2782f..b391e9bd5 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -373,6 +373,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Track if computed vars have changed since last serialization _changed_computed_vars: Set[str] = set() + # Track which computed vars have already been computed + _ready_computed_vars: Set[str] = set() + def __init__( self, parent_state: BaseState | None = None, @@ -2113,11 +2116,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"]["substates"] = {} state["__dict__"].pop("_was_touched", None) state["__dict__"].pop("_changed_computed_vars", None) + state["__dict__"].pop("_ready_computed_vars", None) + state["__fields_set__"].discard("_changed_computed_vars") + state["__fields_set__"].discard("_ready_computed_vars") # Remove all inherited vars. for inherited_var_name in self.inherited_vars: state["__dict__"].pop(inherited_var_name, None) return state + def __setstate__(self, state): + """Set the state from redis deserialization. + + This method is called by pickle to deserialize the object. + + Args: + state: The state dict for deserialization. + """ + super().__setstate__(state) + self._was_touched = False + self._changed_computed_vars = set() + self._ready_computed_vars = set() + def _check_state_size( self, pickle_state_size: int, @@ -3088,6 +3107,8 @@ class StateManagerDisk(StateManager): root_state = self.states.get(client_token) if root_state is not None: # Retrieved state from memory. + root_state._changed_computed_vars = set() + root_state._ready_computed_vars = set() return root_state # Deserialize root state from disk. diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 28499c378..d25cd235b 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -123,6 +123,8 @@ RESERVED_BACKEND_VAR_NAMES = { "_backend_vars", "_was_touched", "_changed_computed_vars", + "_ready_computed_vars", + "_not_persisted_var_cache", } if sys.version_info >= (3, 11): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2891dded4..56b0e6a21 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2022,26 +2022,27 @@ class ComputedVar(Var[RETURN_TYPE]): existing_var=self, ) - if not self._cache: - value = self.fget(instance) - else: - # handle caching - if not hasattr(instance, self._cache_attr) or self.needs_update(instance): - # Get the new value. - new_value = self.fget(instance) - # Get the current cached value. - cached_value = getattr(instance, self._cache_attr, None) - # Check if the new value is different from the cached value. - if new_value == cached_value: - return new_value - instance._changed_computed_vars.add(self._js_expr) - # Set cache attr on state instance. - setattr(instance, self._cache_attr, new_value) - # Ensure the computed var gets serialized to redis. - instance._was_touched = True - # Set the last updated timestamp on the state instance. - setattr(instance, self._last_updated_attr, datetime.datetime.now()) - value = getattr(instance, self._cache_attr) + value = self.get_value(instance) + # if not self._cache: + # value = self.fget(instance) + # else: + # # handle caching + # if not self.has_changed(instance): + # # Get the new value. + # new_value = self.fget(instance) + # # Get the current cached value. + # cached_value = getattr(instance, self._cache_attr, None) + # # Check if the new value is different from the cached value. + # if new_value == cached_value: + # return new_value + # instance._changed_computed_vars.add(self._js_expr) + # # Set cache attr on state instance. + # setattr(instance, self._cache_attr, new_value) + # # Ensure the computed var gets serialized to redis. + # instance._was_touched = True + # # Set the last updated timestamp on the state instance. + # setattr(instance, self._last_updated_attr, datetime.datetime.now()) + # value = getattr(instance, self._cache_attr) if not _isinstance(value, self._var_type): console.deprecate( @@ -2172,8 +2173,56 @@ class ComputedVar(Var[RETURN_TYPE]): Args: instance: the state instance that needs to recompute the value. """ - with contextlib.suppress(AttributeError): - delattr(instance, self._cache_attr) + instance._ready_computed_vars.discard(self._js_expr) + + def already_computed(self, instance: BaseState) -> bool: + """Check if the ComputedVar has already been computed. + + Args: + instance: the state instance that needs to recompute the value. + + Returns: + True if the ComputedVar has already been computed, False otherwise. + """ + if self.needs_update(instance): + return False + return self._js_expr in instance._ready_computed_vars + + def get_value(self, instance: BaseState) -> RETURN_TYPE: + """Get the value of the ComputedVar. + + Args: + instance: the state instance that needs to recompute the value. + + Returns: + The value of the ComputedVar. + """ + if not self._cache: + instance._was_touched = True + new = self.fget(instance) + return new + + has_cache = hasattr(instance, self._cache_attr) + + if self.already_computed(instance): + if has_cache: + return getattr(instance, self._cache_attr) + else: + assert not isinstance(self._initial_value, types.Unset) + return self._initial_value + + cache_value = getattr(instance, self._cache_attr, None) + instance._ready_computed_vars.add(self._js_expr) + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + new_value = self.fget(instance) + # NOTE: does not store initial_value in redis to save space/time + if (has_cache and cache_value != new_value) or ( + not has_cache and new_value != self._initial_value + ): + instance._changed_computed_vars.add(self._js_expr) + instance._was_touched = True + setattr(instance, self._cache_attr, new_value) + return new_value def has_changed(self, instance: BaseState) -> bool: """Check if the ComputedVar value has changed. @@ -2188,15 +2237,9 @@ class ComputedVar(Var[RETURN_TYPE]): return True if self._js_expr in instance._changed_computed_vars: return True - if not hasattr(instance, self._cache_attr): - return True - cached_value = getattr(instance, self._cache_attr) - new_value = self.fget(instance) - has_changed = cached_value != new_value - if has_changed: - instance._changed_computed_vars.add(self._js_expr) - setattr(instance, self._cache_attr, new_value) - return has_changed + if not self.already_computed(instance): + self.get_value(instance) + return self._js_expr in instance._changed_computed_vars def _determine_var_type(self) -> Type: """Get the type of the var. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 8e61b8dae..6cef39eca 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3437,6 +3437,33 @@ def test_fallback_pickle(): assert len(pk3) == 0 +def test_pickle(): + class PickleState(BaseState): + pass + + state = PickleState(_reflex_internal_init=True) # type: ignore + + # test computed var cache is persisted + setattr(state, "__cvcached", 1) + state = PickleState._deserialize(state._serialize()) + assert getattr(state, "__cvcached", None) == 1 + + # test ready computed vars set is not persisted + state._ready_computed_vars = {"foo"} + state = PickleState._deserialize(state._serialize()) + assert not state._ready_computed_vars + + # test that changed computed vars set is not persisted + state._changed_computed_vars = {"foo"} + state = PickleState._deserialize(state._serialize()) + assert not state._changed_computed_vars + + # test was_touched is not persisted + state._was_touched = True + state = PickleState._deserialize(state._serialize()) + assert not state._was_touched + + def test_typed_state() -> None: class TypedState(rx.State): field: rx.Field[str] = rx.field("") From cc1b1a78188ac768b5f44a7cf6fc87a48029739e Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sun, 8 Dec 2024 23:08:35 +0100 Subject: [PATCH 6/9] simplify logic, store initial values for now --- reflex/vars/base.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 56b0e6a21..901bb197a 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2204,21 +2204,14 @@ class ComputedVar(Var[RETURN_TYPE]): has_cache = hasattr(instance, self._cache_attr) - if self.already_computed(instance): - if has_cache: - return getattr(instance, self._cache_attr) - else: - assert not isinstance(self._initial_value, types.Unset) - return self._initial_value + if self.already_computed(instance) and has_cache: + return getattr(instance, self._cache_attr) cache_value = getattr(instance, self._cache_attr, None) instance._ready_computed_vars.add(self._js_expr) setattr(instance, self._last_updated_attr, datetime.datetime.now()) new_value = self.fget(instance) - # NOTE: does not store initial_value in redis to save space/time - if (has_cache and cache_value != new_value) or ( - not has_cache and new_value != self._initial_value - ): + if cache_value != new_value: instance._changed_computed_vars.add(self._js_expr) instance._was_touched = True setattr(instance, self._cache_attr, new_value) From fa3d338d7057b88c807e5237308b79cb6dbaefe8 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sun, 8 Dec 2024 23:09:17 +0100 Subject: [PATCH 7/9] cleanup old code --- reflex/vars/base.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 901bb197a..00599cae4 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2023,26 +2023,6 @@ class ComputedVar(Var[RETURN_TYPE]): ) value = self.get_value(instance) - # if not self._cache: - # value = self.fget(instance) - # else: - # # handle caching - # if not self.has_changed(instance): - # # Get the new value. - # new_value = self.fget(instance) - # # Get the current cached value. - # cached_value = getattr(instance, self._cache_attr, None) - # # Check if the new value is different from the cached value. - # if new_value == cached_value: - # return new_value - # instance._changed_computed_vars.add(self._js_expr) - # # Set cache attr on state instance. - # setattr(instance, self._cache_attr, new_value) - # # Ensure the computed var gets serialized to redis. - # instance._was_touched = True - # # Set the last updated timestamp on the state instance. - # setattr(instance, self._last_updated_attr, datetime.datetime.now()) - # value = getattr(instance, self._cache_attr) if not _isinstance(value, self._var_type): console.deprecate( From b4ed43588fb163d047e955fc65b77bb3286a085a Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Mon, 9 Dec 2024 20:53:21 +0100 Subject: [PATCH 8/9] cleanup --- reflex/utils/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index d25cd235b..404ac0e10 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -124,7 +124,6 @@ RESERVED_BACKEND_VAR_NAMES = { "_was_touched", "_changed_computed_vars", "_ready_computed_vars", - "_not_persisted_var_cache", } if sys.version_info >= (3, 11): From cc6edf164a1ac31124c6b9ab07ecb25361af4730 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Tue, 17 Dec 2024 21:49:38 +0100 Subject: [PATCH 9/9] fix: disable prime cache for computed vars to avoid side effects --- reflex/vars/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 13c5d555a..0101833ba 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2210,8 +2210,7 @@ class ComputedVar(Var[RETURN_TYPE]): return True if self._js_expr in instance._changed_computed_vars: return True - if not self.already_computed(instance): - self.get_value(instance) + # TODO: prime the cache if it's not already? creates side effects and breaks order of computed var execution return self._js_expr in instance._changed_computed_vars def _determine_var_type(self) -> Type: