Merge cc6edf164a
into 2c3257d4ea
This commit is contained in:
commit
71071008b7
@ -398,6 +398,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# A special event handler for setting base vars.
|
||||
setvar: ClassVar[EventHandler]
|
||||
|
||||
# Track if computed vars have changed since last serialization
|
||||
_changed_computed_vars: Set[str] = set()
|
||||
|
||||
# Track which computed vars have already been computed
|
||||
_ready_computed_vars: Set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_state: BaseState | None = None,
|
||||
@ -1903,11 +1909,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
while dirty_vars:
|
||||
calc_vars, dirty_vars = dirty_vars, set()
|
||||
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||
self.dirty_vars.add(cvar)
|
||||
dirty_vars.add(cvar)
|
||||
actual_var = self.computed_vars.get(cvar)
|
||||
if actual_var is not None:
|
||||
assert actual_var is not None
|
||||
if actual_var.has_changed(instance=self):
|
||||
actual_var.mark_dirty(instance=self)
|
||||
self.dirty_vars.add(cvar)
|
||||
dirty_vars.add(cvar)
|
||||
|
||||
def _expired_computed_vars(self) -> set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the expiration time.
|
||||
@ -2187,6 +2194,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
state["__dict__"].pop("parent_state", None)
|
||||
state["__dict__"].pop("substates", None)
|
||||
state["__dict__"].pop("_was_touched", None)
|
||||
state["__dict__"].pop("_changed_computed_vars", None)
|
||||
state["__dict__"].pop("_ready_computed_vars", None)
|
||||
state["__fields_set__"].discard("_changed_computed_vars")
|
||||
state["__fields_set__"].discard("_ready_computed_vars")
|
||||
# Remove all inherited vars.
|
||||
for inherited_var_name in self.inherited_vars:
|
||||
state["__dict__"].pop(inherited_var_name, None)
|
||||
@ -2203,6 +2214,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
state["__dict__"]["parent_state"] = None
|
||||
state["__dict__"]["substates"] = {}
|
||||
super().__setstate__(state)
|
||||
self._was_touched = False
|
||||
self._changed_computed_vars = set()
|
||||
self._ready_computed_vars = set()
|
||||
|
||||
def _check_state_size(
|
||||
self,
|
||||
@ -3190,6 +3204,8 @@ class StateManagerDisk(StateManager):
|
||||
root_state = self.states.get(client_token)
|
||||
if root_state is not None:
|
||||
# Retrieved state from memory.
|
||||
root_state._changed_computed_vars = set()
|
||||
root_state._ready_computed_vars = set()
|
||||
return root_state
|
||||
|
||||
# Deserialize root state from disk.
|
||||
|
@ -125,6 +125,8 @@ RESERVED_BACKEND_VAR_NAMES = {
|
||||
"_abc_impl",
|
||||
"_backend_vars",
|
||||
"_was_touched",
|
||||
"_changed_computed_vars",
|
||||
"_ready_computed_vars",
|
||||
}
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
|
@ -2097,18 +2097,7 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
existing_var=self,
|
||||
)
|
||||
|
||||
if not self._cache:
|
||||
value = self.fget(instance)
|
||||
else:
|
||||
# handle caching
|
||||
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
|
||||
# Set cache attr on state instance.
|
||||
setattr(instance, self._cache_attr, self.fget(instance))
|
||||
# Ensure the computed var gets serialized to redis.
|
||||
instance._was_touched = True
|
||||
# Set the last updated timestamp on the state instance.
|
||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||
value = getattr(instance, self._cache_attr)
|
||||
value = self.get_value(instance)
|
||||
|
||||
if not _isinstance(value, self._var_type):
|
||||
console.error(
|
||||
@ -2236,8 +2225,65 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
Args:
|
||||
instance: the state instance that needs to recompute the value.
|
||||
"""
|
||||
with contextlib.suppress(AttributeError):
|
||||
delattr(instance, self._cache_attr)
|
||||
instance._ready_computed_vars.discard(self._js_expr)
|
||||
|
||||
def already_computed(self, instance: BaseState) -> bool:
|
||||
"""Check if the ComputedVar has already been computed.
|
||||
|
||||
Args:
|
||||
instance: the state instance that needs to recompute the value.
|
||||
|
||||
Returns:
|
||||
True if the ComputedVar has already been computed, False otherwise.
|
||||
"""
|
||||
if self.needs_update(instance):
|
||||
return False
|
||||
return self._js_expr in instance._ready_computed_vars
|
||||
|
||||
def get_value(self, instance: BaseState) -> RETURN_TYPE:
|
||||
"""Get the value of the ComputedVar.
|
||||
|
||||
Args:
|
||||
instance: the state instance that needs to recompute the value.
|
||||
|
||||
Returns:
|
||||
The value of the ComputedVar.
|
||||
"""
|
||||
if not self._cache:
|
||||
instance._was_touched = True
|
||||
new = self.fget(instance)
|
||||
return new
|
||||
|
||||
has_cache = hasattr(instance, self._cache_attr)
|
||||
|
||||
if self.already_computed(instance) and has_cache:
|
||||
return getattr(instance, self._cache_attr)
|
||||
|
||||
cache_value = getattr(instance, self._cache_attr, None)
|
||||
instance._ready_computed_vars.add(self._js_expr)
|
||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||
new_value = self.fget(instance)
|
||||
if cache_value != new_value:
|
||||
instance._changed_computed_vars.add(self._js_expr)
|
||||
instance._was_touched = True
|
||||
setattr(instance, self._cache_attr, new_value)
|
||||
return new_value
|
||||
|
||||
def has_changed(self, instance: BaseState) -> bool:
|
||||
"""Check if the ComputedVar value has changed.
|
||||
|
||||
Args:
|
||||
instance: the state instance that needs to recompute the value.
|
||||
|
||||
Returns:
|
||||
True if the value has changed, False otherwise.
|
||||
"""
|
||||
if not self._cache:
|
||||
return True
|
||||
if self._js_expr in instance._changed_computed_vars:
|
||||
return True
|
||||
# TODO: prime the cache if it's not already? creates side effects and breaks order of computed var execution
|
||||
return self._js_expr in instance._changed_computed_vars
|
||||
|
||||
def _determine_var_type(self) -> Type:
|
||||
"""Get the type of the var.
|
||||
|
@ -3569,6 +3569,33 @@ def test_fallback_pickle():
|
||||
_ = state3._serialize()
|
||||
|
||||
|
||||
def test_pickle():
|
||||
class PickleState(BaseState):
|
||||
pass
|
||||
|
||||
state = PickleState(_reflex_internal_init=True) # type: ignore
|
||||
|
||||
# test computed var cache is persisted
|
||||
setattr(state, "__cvcached", 1)
|
||||
state = PickleState._deserialize(state._serialize())
|
||||
assert getattr(state, "__cvcached", None) == 1
|
||||
|
||||
# test ready computed vars set is not persisted
|
||||
state._ready_computed_vars = {"foo"}
|
||||
state = PickleState._deserialize(state._serialize())
|
||||
assert not state._ready_computed_vars
|
||||
|
||||
# test that changed computed vars set is not persisted
|
||||
state._changed_computed_vars = {"foo"}
|
||||
state = PickleState._deserialize(state._serialize())
|
||||
assert not state._changed_computed_vars
|
||||
|
||||
# test was_touched is not persisted
|
||||
state._was_touched = True
|
||||
state = PickleState._deserialize(state._serialize())
|
||||
assert not state._was_touched
|
||||
|
||||
|
||||
def test_typed_state() -> None:
|
||||
class TypedState(rx.State):
|
||||
field: rx.Field[str] = rx.field("")
|
||||
|
Loading…
Reference in New Issue
Block a user