track which computed vars have already been computed
This commit is contained in:
parent
f5987ea652
commit
1a3a230929
@ -373,6 +373,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Track if computed vars have changed since last serialization
|
# Track if computed vars have changed since last serialization
|
||||||
_changed_computed_vars: Set[str] = set()
|
_changed_computed_vars: Set[str] = set()
|
||||||
|
|
||||||
|
# Track which computed vars have already been computed
|
||||||
|
_ready_computed_vars: Set[str] = set()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent_state: BaseState | None = None,
|
parent_state: BaseState | None = None,
|
||||||
@ -2113,11 +2116,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
state["__dict__"]["substates"] = {}
|
state["__dict__"]["substates"] = {}
|
||||||
state["__dict__"].pop("_was_touched", None)
|
state["__dict__"].pop("_was_touched", None)
|
||||||
state["__dict__"].pop("_changed_computed_vars", None)
|
state["__dict__"].pop("_changed_computed_vars", None)
|
||||||
|
state["__dict__"].pop("_ready_computed_vars", None)
|
||||||
|
state["__fields_set__"].discard("_changed_computed_vars")
|
||||||
|
state["__fields_set__"].discard("_ready_computed_vars")
|
||||||
# Remove all inherited vars.
|
# Remove all inherited vars.
|
||||||
for inherited_var_name in self.inherited_vars:
|
for inherited_var_name in self.inherited_vars:
|
||||||
state["__dict__"].pop(inherited_var_name, None)
|
state["__dict__"].pop(inherited_var_name, None)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
"""Set the state from redis deserialization.
|
||||||
|
|
||||||
|
This method is called by pickle to deserialize the object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state dict for deserialization.
|
||||||
|
"""
|
||||||
|
super().__setstate__(state)
|
||||||
|
self._was_touched = False
|
||||||
|
self._changed_computed_vars = set()
|
||||||
|
self._ready_computed_vars = set()
|
||||||
|
|
||||||
def _check_state_size(
|
def _check_state_size(
|
||||||
self,
|
self,
|
||||||
pickle_state_size: int,
|
pickle_state_size: int,
|
||||||
@ -3088,6 +3107,8 @@ class StateManagerDisk(StateManager):
|
|||||||
root_state = self.states.get(client_token)
|
root_state = self.states.get(client_token)
|
||||||
if root_state is not None:
|
if root_state is not None:
|
||||||
# Retrieved state from memory.
|
# Retrieved state from memory.
|
||||||
|
root_state._changed_computed_vars = set()
|
||||||
|
root_state._ready_computed_vars = set()
|
||||||
return root_state
|
return root_state
|
||||||
|
|
||||||
# Deserialize root state from disk.
|
# Deserialize root state from disk.
|
||||||
|
@ -123,6 +123,8 @@ RESERVED_BACKEND_VAR_NAMES = {
|
|||||||
"_backend_vars",
|
"_backend_vars",
|
||||||
"_was_touched",
|
"_was_touched",
|
||||||
"_changed_computed_vars",
|
"_changed_computed_vars",
|
||||||
|
"_ready_computed_vars",
|
||||||
|
"_not_persisted_var_cache",
|
||||||
}
|
}
|
||||||
|
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
|
@ -2022,26 +2022,27 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
existing_var=self,
|
existing_var=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self._cache:
|
value = self.get_value(instance)
|
||||||
value = self.fget(instance)
|
# if not self._cache:
|
||||||
else:
|
# value = self.fget(instance)
|
||||||
# handle caching
|
# else:
|
||||||
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
|
# # handle caching
|
||||||
# Get the new value.
|
# if not self.has_changed(instance):
|
||||||
new_value = self.fget(instance)
|
# # Get the new value.
|
||||||
# Get the current cached value.
|
# new_value = self.fget(instance)
|
||||||
cached_value = getattr(instance, self._cache_attr, None)
|
# # Get the current cached value.
|
||||||
# Check if the new value is different from the cached value.
|
# cached_value = getattr(instance, self._cache_attr, None)
|
||||||
if new_value == cached_value:
|
# # Check if the new value is different from the cached value.
|
||||||
return new_value
|
# if new_value == cached_value:
|
||||||
instance._changed_computed_vars.add(self._js_expr)
|
# return new_value
|
||||||
# Set cache attr on state instance.
|
# instance._changed_computed_vars.add(self._js_expr)
|
||||||
setattr(instance, self._cache_attr, new_value)
|
# # Set cache attr on state instance.
|
||||||
# Ensure the computed var gets serialized to redis.
|
# setattr(instance, self._cache_attr, new_value)
|
||||||
instance._was_touched = True
|
# # Ensure the computed var gets serialized to redis.
|
||||||
# Set the last updated timestamp on the state instance.
|
# instance._was_touched = True
|
||||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
# # Set the last updated timestamp on the state instance.
|
||||||
value = getattr(instance, self._cache_attr)
|
# setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
|
# value = getattr(instance, self._cache_attr)
|
||||||
|
|
||||||
if not _isinstance(value, self._var_type):
|
if not _isinstance(value, self._var_type):
|
||||||
console.deprecate(
|
console.deprecate(
|
||||||
@ -2172,8 +2173,56 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
Args:
|
Args:
|
||||||
instance: the state instance that needs to recompute the value.
|
instance: the state instance that needs to recompute the value.
|
||||||
"""
|
"""
|
||||||
with contextlib.suppress(AttributeError):
|
instance._ready_computed_vars.discard(self._js_expr)
|
||||||
delattr(instance, self._cache_attr)
|
|
||||||
|
def already_computed(self, instance: BaseState) -> bool:
|
||||||
|
"""Check if the ComputedVar has already been computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: the state instance that needs to recompute the value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the ComputedVar has already been computed, False otherwise.
|
||||||
|
"""
|
||||||
|
if self.needs_update(instance):
|
||||||
|
return False
|
||||||
|
return self._js_expr in instance._ready_computed_vars
|
||||||
|
|
||||||
|
def get_value(self, instance: BaseState) -> RETURN_TYPE:
|
||||||
|
"""Get the value of the ComputedVar.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: the state instance that needs to recompute the value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the ComputedVar.
|
||||||
|
"""
|
||||||
|
if not self._cache:
|
||||||
|
instance._was_touched = True
|
||||||
|
new = self.fget(instance)
|
||||||
|
return new
|
||||||
|
|
||||||
|
has_cache = hasattr(instance, self._cache_attr)
|
||||||
|
|
||||||
|
if self.already_computed(instance):
|
||||||
|
if has_cache:
|
||||||
|
return getattr(instance, self._cache_attr)
|
||||||
|
else:
|
||||||
|
assert not isinstance(self._initial_value, types.Unset)
|
||||||
|
return self._initial_value
|
||||||
|
|
||||||
|
cache_value = getattr(instance, self._cache_attr, None)
|
||||||
|
instance._ready_computed_vars.add(self._js_expr)
|
||||||
|
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
|
new_value = self.fget(instance)
|
||||||
|
# NOTE: does not store initial_value in redis to save space/time
|
||||||
|
if (has_cache and cache_value != new_value) or (
|
||||||
|
not has_cache and new_value != self._initial_value
|
||||||
|
):
|
||||||
|
instance._changed_computed_vars.add(self._js_expr)
|
||||||
|
instance._was_touched = True
|
||||||
|
setattr(instance, self._cache_attr, new_value)
|
||||||
|
return new_value
|
||||||
|
|
||||||
def has_changed(self, instance: BaseState) -> bool:
|
def has_changed(self, instance: BaseState) -> bool:
|
||||||
"""Check if the ComputedVar value has changed.
|
"""Check if the ComputedVar value has changed.
|
||||||
@ -2188,15 +2237,9 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
return True
|
return True
|
||||||
if self._js_expr in instance._changed_computed_vars:
|
if self._js_expr in instance._changed_computed_vars:
|
||||||
return True
|
return True
|
||||||
if not hasattr(instance, self._cache_attr):
|
if not self.already_computed(instance):
|
||||||
return True
|
self.get_value(instance)
|
||||||
cached_value = getattr(instance, self._cache_attr)
|
return self._js_expr in instance._changed_computed_vars
|
||||||
new_value = self.fget(instance)
|
|
||||||
has_changed = cached_value != new_value
|
|
||||||
if has_changed:
|
|
||||||
instance._changed_computed_vars.add(self._js_expr)
|
|
||||||
setattr(instance, self._cache_attr, new_value)
|
|
||||||
return has_changed
|
|
||||||
|
|
||||||
def _determine_var_type(self) -> Type:
|
def _determine_var_type(self) -> Type:
|
||||||
"""Get the type of the var.
|
"""Get the type of the var.
|
||||||
|
@ -3437,6 +3437,33 @@ def test_fallback_pickle():
|
|||||||
assert len(pk3) == 0
|
assert len(pk3) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickle():
|
||||||
|
class PickleState(BaseState):
|
||||||
|
pass
|
||||||
|
|
||||||
|
state = PickleState(_reflex_internal_init=True) # type: ignore
|
||||||
|
|
||||||
|
# test computed var cache is persisted
|
||||||
|
setattr(state, "__cvcached", 1)
|
||||||
|
state = PickleState._deserialize(state._serialize())
|
||||||
|
assert getattr(state, "__cvcached", None) == 1
|
||||||
|
|
||||||
|
# test ready computed vars set is not persisted
|
||||||
|
state._ready_computed_vars = {"foo"}
|
||||||
|
state = PickleState._deserialize(state._serialize())
|
||||||
|
assert not state._ready_computed_vars
|
||||||
|
|
||||||
|
# test that changed computed vars set is not persisted
|
||||||
|
state._changed_computed_vars = {"foo"}
|
||||||
|
state = PickleState._deserialize(state._serialize())
|
||||||
|
assert not state._changed_computed_vars
|
||||||
|
|
||||||
|
# test was_touched is not persisted
|
||||||
|
state._was_touched = True
|
||||||
|
state = PickleState._deserialize(state._serialize())
|
||||||
|
assert not state._was_touched
|
||||||
|
|
||||||
|
|
||||||
def test_typed_state() -> None:
|
def test_typed_state() -> None:
|
||||||
class TypedState(rx.State):
|
class TypedState(rx.State):
|
||||||
field: rx.Field[str] = rx.field("")
|
field: rx.Field[str] = rx.field("")
|
||||||
|
Loading…
Reference in New Issue
Block a user