track which computed vars have already been computed

This commit is contained in:
Benedikt Bartscher 2024-12-08 22:55:58 +01:00
parent f5987ea652
commit 1a3a230929
No known key found for this signature in database
4 changed files with 124 additions and 31 deletions

View File

@ -373,6 +373,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Track if computed vars have changed since last serialization # Track if computed vars have changed since last serialization
_changed_computed_vars: Set[str] = set() _changed_computed_vars: Set[str] = set()
# Track which computed vars have already been computed
_ready_computed_vars: Set[str] = set()
def __init__( def __init__(
self, self,
parent_state: BaseState | None = None, parent_state: BaseState | None = None,
@ -2113,11 +2116,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
state["__dict__"]["substates"] = {} state["__dict__"]["substates"] = {}
state["__dict__"].pop("_was_touched", None) state["__dict__"].pop("_was_touched", None)
state["__dict__"].pop("_changed_computed_vars", None) state["__dict__"].pop("_changed_computed_vars", None)
state["__dict__"].pop("_ready_computed_vars", None)
state["__fields_set__"].discard("_changed_computed_vars")
state["__fields_set__"].discard("_ready_computed_vars")
# Remove all inherited vars. # Remove all inherited vars.
for inherited_var_name in self.inherited_vars: for inherited_var_name in self.inherited_vars:
state["__dict__"].pop(inherited_var_name, None) state["__dict__"].pop(inherited_var_name, None)
return state return state
def __setstate__(self, state):
"""Set the state from redis deserialization.
This method is called by pickle to deserialize the object.
Args:
state: The state dict for deserialization.
"""
super().__setstate__(state)
self._was_touched = False
self._changed_computed_vars = set()
self._ready_computed_vars = set()
def _check_state_size( def _check_state_size(
self, self,
pickle_state_size: int, pickle_state_size: int,
@ -3088,6 +3107,8 @@ class StateManagerDisk(StateManager):
root_state = self.states.get(client_token) root_state = self.states.get(client_token)
if root_state is not None: if root_state is not None:
# Retrieved state from memory. # Retrieved state from memory.
root_state._changed_computed_vars = set()
root_state._ready_computed_vars = set()
return root_state return root_state
# Deserialize root state from disk. # Deserialize root state from disk.

View File

@ -123,6 +123,8 @@ RESERVED_BACKEND_VAR_NAMES = {
"_backend_vars", "_backend_vars",
"_was_touched", "_was_touched",
"_changed_computed_vars", "_changed_computed_vars",
"_ready_computed_vars",
"_not_persisted_var_cache",
} }
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):

View File

@ -2022,26 +2022,27 @@ class ComputedVar(Var[RETURN_TYPE]):
existing_var=self, existing_var=self,
) )
if not self._cache: value = self.get_value(instance)
value = self.fget(instance) # if not self._cache:
else: # value = self.fget(instance)
# handle caching # else:
if not hasattr(instance, self._cache_attr) or self.needs_update(instance): # # handle caching
# Get the new value. # if not self.has_changed(instance):
new_value = self.fget(instance) # # Get the new value.
# Get the current cached value. # new_value = self.fget(instance)
cached_value = getattr(instance, self._cache_attr, None) # # Get the current cached value.
# Check if the new value is different from the cached value. # cached_value = getattr(instance, self._cache_attr, None)
if new_value == cached_value: # # Check if the new value is different from the cached value.
return new_value # if new_value == cached_value:
instance._changed_computed_vars.add(self._js_expr) # return new_value
# Set cache attr on state instance. # instance._changed_computed_vars.add(self._js_expr)
setattr(instance, self._cache_attr, new_value) # # Set cache attr on state instance.
# Ensure the computed var gets serialized to redis. # setattr(instance, self._cache_attr, new_value)
instance._was_touched = True # # Ensure the computed var gets serialized to redis.
# Set the last updated timestamp on the state instance. # instance._was_touched = True
setattr(instance, self._last_updated_attr, datetime.datetime.now()) # # Set the last updated timestamp on the state instance.
value = getattr(instance, self._cache_attr) # setattr(instance, self._last_updated_attr, datetime.datetime.now())
# value = getattr(instance, self._cache_attr)
if not _isinstance(value, self._var_type): if not _isinstance(value, self._var_type):
console.deprecate( console.deprecate(
@ -2172,8 +2173,56 @@ class ComputedVar(Var[RETURN_TYPE]):
Args: Args:
instance: the state instance that needs to recompute the value. instance: the state instance that needs to recompute the value.
""" """
with contextlib.suppress(AttributeError): instance._ready_computed_vars.discard(self._js_expr)
delattr(instance, self._cache_attr)
def already_computed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar has already been computed.
Args:
instance: the state instance that needs to recompute the value.
Returns:
True if the ComputedVar has already been computed, False otherwise.
"""
if self.needs_update(instance):
return False
return self._js_expr in instance._ready_computed_vars
def get_value(self, instance: BaseState) -> RETURN_TYPE:
"""Get the value of the ComputedVar.
Args:
instance: the state instance that needs to recompute the value.
Returns:
The value of the ComputedVar.
"""
if not self._cache:
instance._was_touched = True
new = self.fget(instance)
return new
has_cache = hasattr(instance, self._cache_attr)
if self.already_computed(instance):
if has_cache:
return getattr(instance, self._cache_attr)
else:
assert not isinstance(self._initial_value, types.Unset)
return self._initial_value
cache_value = getattr(instance, self._cache_attr, None)
instance._ready_computed_vars.add(self._js_expr)
setattr(instance, self._last_updated_attr, datetime.datetime.now())
new_value = self.fget(instance)
# NOTE: does not store initial_value in redis to save space/time
if (has_cache and cache_value != new_value) or (
not has_cache and new_value != self._initial_value
):
instance._changed_computed_vars.add(self._js_expr)
instance._was_touched = True
setattr(instance, self._cache_attr, new_value)
return new_value
def has_changed(self, instance: BaseState) -> bool: def has_changed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar value has changed. """Check if the ComputedVar value has changed.
@ -2188,15 +2237,9 @@ class ComputedVar(Var[RETURN_TYPE]):
return True return True
if self._js_expr in instance._changed_computed_vars: if self._js_expr in instance._changed_computed_vars:
return True return True
if not hasattr(instance, self._cache_attr): if not self.already_computed(instance):
return True self.get_value(instance)
cached_value = getattr(instance, self._cache_attr) return self._js_expr in instance._changed_computed_vars
new_value = self.fget(instance)
has_changed = cached_value != new_value
if has_changed:
instance._changed_computed_vars.add(self._js_expr)
setattr(instance, self._cache_attr, new_value)
return has_changed
def _determine_var_type(self) -> Type: def _determine_var_type(self) -> Type:
"""Get the type of the var. """Get the type of the var.

View File

@ -3437,6 +3437,33 @@ def test_fallback_pickle():
assert len(pk3) == 0 assert len(pk3) == 0
def test_pickle():
class PickleState(BaseState):
pass
state = PickleState(_reflex_internal_init=True) # type: ignore
# test computed var cache is persisted
setattr(state, "__cvcached", 1)
state = PickleState._deserialize(state._serialize())
assert getattr(state, "__cvcached", None) == 1
# test ready computed vars set is not persisted
state._ready_computed_vars = {"foo"}
state = PickleState._deserialize(state._serialize())
assert not state._ready_computed_vars
# test that changed computed vars set is not persisted
state._changed_computed_vars = {"foo"}
state = PickleState._deserialize(state._serialize())
assert not state._changed_computed_vars
# test was_touched is not persisted
state._was_touched = True
state = PickleState._deserialize(state._serialize())
assert not state._was_touched
def test_typed_state() -> None: def test_typed_state() -> None:
class TypedState(rx.State): class TypedState(rx.State):
field: rx.Field[str] = rx.field("") field: rx.Field[str] = rx.field("")