Cache ComputedVar (#917)
This commit is contained in:
parent
bad2363506
commit
c344a5c0d7
@ -74,12 +74,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||
|
||||
# Whether to track accessed vars.
|
||||
track_vars: bool = False
|
||||
|
||||
# The current set of accessed vars during tracking.
|
||||
tracked_vars: Set[str] = set()
|
||||
|
||||
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
||||
"""Initialize the state.
|
||||
|
||||
@ -102,22 +96,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||
setattr(self, name, fn)
|
||||
|
||||
# Initialize the mutable fields.
|
||||
self._init_mutable_fields()
|
||||
|
||||
# Initialize computed vars dependencies.
|
||||
self.computed_var_dependencies = defaultdict(set)
|
||||
for cvar in self.computed_vars:
|
||||
self.tracked_vars = set()
|
||||
|
||||
# Enable tracking and get the computed var.
|
||||
self.track_vars = True
|
||||
self.__getattribute__(cvar)
|
||||
self.track_vars = False
|
||||
|
||||
for cvar_name, cvar in self.computed_vars.items():
|
||||
# Add the dependencies.
|
||||
for var in self.tracked_vars:
|
||||
self.computed_var_dependencies[var].add(cvar)
|
||||
for var in cvar.deps():
|
||||
self.computed_var_dependencies[var].add(cvar_name)
|
||||
|
||||
# Initialize the mutable fields.
|
||||
self._init_mutable_fields()
|
||||
|
||||
def _init_mutable_fields(self):
|
||||
"""Initialize mutable fields.
|
||||
@ -199,7 +186,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
**cls.base_vars,
|
||||
**cls.computed_vars,
|
||||
}
|
||||
cls.computed_var_dependencies = {}
|
||||
cls.event_handlers = {}
|
||||
|
||||
# Setup the base vars at the class level.
|
||||
@ -233,8 +219,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
"computed_var_dependencies",
|
||||
"track_vars",
|
||||
"tracked_vars",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -508,8 +492,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
If the var is inherited, get the var from the parent state.
|
||||
|
||||
If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the var.
|
||||
|
||||
@ -520,17 +502,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if not super().__getattribute__("__dict__"):
|
||||
return super().__getattribute__(name)
|
||||
|
||||
# Check if tracking is enabled.
|
||||
if super().__getattribute__("track_vars"):
|
||||
# Get the non-computed vars.
|
||||
all_vars = {
|
||||
**super().__getattribute__("vars"),
|
||||
**super().__getattribute__("backend_vars"),
|
||||
}
|
||||
# Add the var to the tracked vars.
|
||||
if name in all_vars:
|
||||
super().__getattribute__("tracked_vars").add(name)
|
||||
|
||||
inherited_vars = {
|
||||
**super().__getattribute__("inherited_vars"),
|
||||
**super().__getattribute__("inherited_backend_vars"),
|
||||
@ -676,55 +647,58 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Return the state update.
|
||||
return StateUpdate(delta=delta, events=events)
|
||||
|
||||
def _dirty_computed_vars(
|
||||
self, from_vars: Optional[Set[str]] = None, check: bool = False
|
||||
) -> Set[str]:
|
||||
"""Get ComputedVars that need to be recomputed based on dirty_vars.
|
||||
def _mark_dirty_computed_vars(self) -> None:
|
||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||
dirty_vars = self.dirty_vars
|
||||
while dirty_vars:
|
||||
calc_vars, dirty_vars = dirty_vars, set()
|
||||
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||
self.dirty_vars.add(cvar)
|
||||
dirty_vars.add(cvar)
|
||||
actual_var = self.computed_vars.get(cvar)
|
||||
if actual_var:
|
||||
actual_var.mark_dirty(instance=self)
|
||||
|
||||
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||
|
||||
Args:
|
||||
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
||||
check: Whether to perform the check.
|
||||
|
||||
Returns:
|
||||
Set of computed vars to include in the delta.
|
||||
"""
|
||||
# If checking is disabled, return all computed vars.
|
||||
if not check:
|
||||
return set(self.computed_vars)
|
||||
|
||||
# Return only the computed vars that depend on the dirty vars.
|
||||
return set(
|
||||
cvar
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self.computed_vars
|
||||
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
||||
for cvar in self.computed_var_dependencies[dirty_var]
|
||||
)
|
||||
|
||||
def get_delta(self, check: bool = False) -> Delta:
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
|
||||
Args:
|
||||
check: Whether to check for dirty computed vars.
|
||||
|
||||
Returns:
|
||||
The delta for the state.
|
||||
"""
|
||||
delta = {}
|
||||
|
||||
# Return the dirty vars, as well as computed vars depending on dirty vars.
|
||||
subdelta = {
|
||||
prop: getattr(self, prop)
|
||||
for prop in self.dirty_vars | self._dirty_computed_vars(check=check)
|
||||
if not types.is_backend_variable(prop)
|
||||
}
|
||||
if len(subdelta) > 0:
|
||||
delta[self.get_full_name()] = subdelta
|
||||
|
||||
# Recursively find the substate deltas.
|
||||
substates = self.substates
|
||||
for substate in self.dirty_substates:
|
||||
delta.update(substates[substate].get_delta())
|
||||
|
||||
# Return the dirty vars and dependent computed vars
|
||||
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||
self._dirty_computed_vars()
|
||||
)
|
||||
subdelta = {
|
||||
prop: getattr(self, prop)
|
||||
for prop in delta_vars
|
||||
if not types.is_backend_variable(prop)
|
||||
}
|
||||
if len(subdelta) > 0:
|
||||
delta[self.get_full_name()] = subdelta
|
||||
|
||||
# Format the delta.
|
||||
delta = format.format_state(delta)
|
||||
|
||||
@ -737,6 +711,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self.parent_state.dirty_substates.add(self.get_name())
|
||||
self.parent_state.mark_dirty()
|
||||
|
||||
# have to mark computed vars dirty to allow access to newly computed
|
||||
# values within the same ComputedVar function
|
||||
self._mark_dirty_computed_vars()
|
||||
|
||||
def clean(self):
|
||||
"""Reset the dirty vars."""
|
||||
# Recursively clean the substates.
|
||||
|
@ -1,10 +1,13 @@
|
||||
"""Define a state var."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dis
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
from abc import ABC
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -12,9 +15,11 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias, # type: ignore
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
@ -801,6 +806,84 @@ class ComputedVar(property, Var):
|
||||
assert self.fget is not None, "Var must have a getter."
|
||||
return self.fget.__name__
|
||||
|
||||
@property
|
||||
def cache_attr(self) -> str:
|
||||
"""Get the attribute used to cache the value on the instance.
|
||||
|
||||
Returns:
|
||||
An attribute name.
|
||||
"""
|
||||
return f"__cached_{self.name}"
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""Get the ComputedVar value.
|
||||
|
||||
If the value is already cached on the instance, return the cached value.
|
||||
|
||||
If this ComputedVar doesn't know what type of object it is attached to, then save
|
||||
a reference as self.__objclass__.
|
||||
|
||||
Args:
|
||||
instance: the instance of the class accessing this computed var.
|
||||
owner: the class that this descriptor is attached to.
|
||||
|
||||
Returns:
|
||||
The value of the var for the given instance.
|
||||
"""
|
||||
if not hasattr(self, "__objclass__"):
|
||||
self.__objclass__ = owner
|
||||
|
||||
if instance is None:
|
||||
return super().__get__(instance, owner)
|
||||
|
||||
# handle caching
|
||||
if not hasattr(instance, self.cache_attr):
|
||||
setattr(instance, self.cache_attr, super().__get__(instance, owner))
|
||||
return getattr(instance, self.cache_attr)
|
||||
|
||||
def deps(self, obj: Optional[FunctionType] = None) -> Set[str]:
|
||||
"""Determine var dependencies of this ComputedVar.
|
||||
|
||||
Save references to attributes accessed on "self". Recursively called
|
||||
when the function makes a method call on "self".
|
||||
|
||||
Args:
|
||||
obj: the object to disassemble (defaults to the fget function).
|
||||
|
||||
Returns:
|
||||
A set of variable names accessed by the given obj.
|
||||
"""
|
||||
d = set()
|
||||
if obj is None:
|
||||
if self.fget is not None:
|
||||
obj = cast(FunctionType, self.fget)
|
||||
else:
|
||||
return set()
|
||||
if not obj.__code__.co_varnames:
|
||||
# cannot reference self if method takes no args
|
||||
return set()
|
||||
self_name = obj.__code__.co_varnames[0]
|
||||
self_is_top_of_stack = False
|
||||
for instruction in dis.get_instructions(obj):
|
||||
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
|
||||
self_is_top_of_stack = True
|
||||
continue
|
||||
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
|
||||
d.add(instruction.argval)
|
||||
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
|
||||
d.update(self.deps(obj=getattr(self.__objclass__, instruction.argval)))
|
||||
self_is_top_of_stack = False
|
||||
return d
|
||||
|
||||
def mark_dirty(self, instance) -> None:
|
||||
"""Mark this ComputedVar as dirty.
|
||||
|
||||
Args:
|
||||
instance: the state instance that needs to recompute the value.
|
||||
"""
|
||||
with contextlib.suppress(AttributeError):
|
||||
delattr(instance, self.cache_attr)
|
||||
|
||||
@property
|
||||
def type_(self):
|
||||
"""Get the type of the var.
|
||||
|
@ -485,11 +485,11 @@ def test_set_dirty_var(test_state):
|
||||
|
||||
# Setting a var should mark it as dirty.
|
||||
test_state.num1 = 1
|
||||
assert test_state.dirty_vars == {"num1"}
|
||||
assert test_state.dirty_vars == {"num1", "sum"}
|
||||
|
||||
# Setting another var should mark it as dirty.
|
||||
test_state.num2 = 2
|
||||
assert test_state.dirty_vars == {"num1", "num2"}
|
||||
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||
|
||||
# Cleaning the state should remove all dirty vars.
|
||||
test_state.clean()
|
||||
@ -578,7 +578,7 @@ async def test_process_event_simple(test_state):
|
||||
assert test_state.num1 == 69
|
||||
|
||||
# The delta should contain the changes, including computed vars.
|
||||
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
|
||||
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
|
||||
assert update.events == []
|
||||
|
||||
|
||||
@ -601,7 +601,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
assert child_state.value == "HI"
|
||||
assert child_state.count == 24
|
||||
assert update.delta == {
|
||||
"test_state": {"sum": 3.14, "upper": ""},
|
||||
"test_state.child_state": {"value": "HI", "count": 24},
|
||||
}
|
||||
test_state.clean()
|
||||
@ -616,7 +615,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
update = await test_state._process(event)
|
||||
assert grandchild_state.value2 == "new"
|
||||
assert update.delta == {
|
||||
"test_state": {"sum": 3.14, "upper": ""},
|
||||
"test_state.child_state.grandchild_state": {"value2": "new"},
|
||||
}
|
||||
|
||||
@ -791,7 +789,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state):
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state.x = 5
|
||||
assert interdependent_state.get_delta(check=True) == {
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"x": 5},
|
||||
}
|
||||
|
||||
@ -806,7 +804,7 @@ def test_dirty_computed_var_from_var(interdependent_state):
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state.v1 = 1
|
||||
assert interdependent_state.get_delta(check=True) == {
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||
}
|
||||
|
||||
@ -818,7 +816,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state._v2 = 2
|
||||
assert interdependent_state.get_delta(check=True) == {
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||
}
|
||||
|
||||
@ -860,6 +858,7 @@ def test_conditional_computed_vars():
|
||||
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
|
||||
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
|
||||
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
|
||||
assert ms.computed_vars["rendered_var"].deps() == {"flag", "t1", "t2"}
|
||||
|
||||
|
||||
def test_event_handlers_convert_to_fns(test_state, child_state):
|
||||
@ -896,3 +895,29 @@ def test_event_handlers_call_other_handlers():
|
||||
ms = MainState()
|
||||
ms.set_v2(1)
|
||||
assert ms.v == 1
|
||||
|
||||
|
||||
def test_computed_var_cached():
|
||||
"""Test that a ComputedVar doesn't recalculate when accessed."""
|
||||
comp_v_calls = 0
|
||||
|
||||
class ComputedState(State):
|
||||
v: int = 0
|
||||
|
||||
@ComputedVar
|
||||
def comp_v(self) -> int:
|
||||
nonlocal comp_v_calls
|
||||
comp_v_calls += 1
|
||||
return self.v
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs.dict()["v"] == 0
|
||||
assert comp_v_calls == 1
|
||||
assert cs.dict()["comp_v"] == 0
|
||||
assert comp_v_calls == 1
|
||||
assert cs.comp_v == 0
|
||||
assert comp_v_calls == 1
|
||||
cs.v = 1
|
||||
assert comp_v_calls == 1
|
||||
assert cs.comp_v == 1
|
||||
assert comp_v_calls == 2
|
||||
|
Loading…
Reference in New Issue
Block a user