Cache ComputedVar (#917)
This commit is contained in:
parent
bad2363506
commit
c344a5c0d7
@ -74,12 +74,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Mapping of var name to set of computed variables that depend on it
|
# Mapping of var name to set of computed variables that depend on it
|
||||||
computed_var_dependencies: Dict[str, Set[str]] = {}
|
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
# Whether to track accessed vars.
|
|
||||||
track_vars: bool = False
|
|
||||||
|
|
||||||
# The current set of accessed vars during tracking.
|
|
||||||
tracked_vars: Set[str] = set()
|
|
||||||
|
|
||||||
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
||||||
"""Initialize the state.
|
"""Initialize the state.
|
||||||
|
|
||||||
@ -102,22 +96,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||||
setattr(self, name, fn)
|
setattr(self, name, fn)
|
||||||
|
|
||||||
# Initialize the mutable fields.
|
|
||||||
self._init_mutable_fields()
|
|
||||||
|
|
||||||
# Initialize computed vars dependencies.
|
# Initialize computed vars dependencies.
|
||||||
self.computed_var_dependencies = defaultdict(set)
|
self.computed_var_dependencies = defaultdict(set)
|
||||||
for cvar in self.computed_vars:
|
for cvar_name, cvar in self.computed_vars.items():
|
||||||
self.tracked_vars = set()
|
|
||||||
|
|
||||||
# Enable tracking and get the computed var.
|
|
||||||
self.track_vars = True
|
|
||||||
self.__getattribute__(cvar)
|
|
||||||
self.track_vars = False
|
|
||||||
|
|
||||||
# Add the dependencies.
|
# Add the dependencies.
|
||||||
for var in self.tracked_vars:
|
for var in cvar.deps():
|
||||||
self.computed_var_dependencies[var].add(cvar)
|
self.computed_var_dependencies[var].add(cvar_name)
|
||||||
|
|
||||||
|
# Initialize the mutable fields.
|
||||||
|
self._init_mutable_fields()
|
||||||
|
|
||||||
def _init_mutable_fields(self):
|
def _init_mutable_fields(self):
|
||||||
"""Initialize mutable fields.
|
"""Initialize mutable fields.
|
||||||
@ -199,7 +186,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
**cls.base_vars,
|
**cls.base_vars,
|
||||||
**cls.computed_vars,
|
**cls.computed_vars,
|
||||||
}
|
}
|
||||||
cls.computed_var_dependencies = {}
|
|
||||||
cls.event_handlers = {}
|
cls.event_handlers = {}
|
||||||
|
|
||||||
# Setup the base vars at the class level.
|
# Setup the base vars at the class level.
|
||||||
@ -233,8 +219,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"dirty_substates",
|
"dirty_substates",
|
||||||
"router_data",
|
"router_data",
|
||||||
"computed_var_dependencies",
|
"computed_var_dependencies",
|
||||||
"track_vars",
|
|
||||||
"tracked_vars",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -508,8 +492,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
If the var is inherited, get the var from the parent state.
|
If the var is inherited, get the var from the parent state.
|
||||||
|
|
||||||
If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the var.
|
name: The name of the var.
|
||||||
|
|
||||||
@ -520,17 +502,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if not super().__getattribute__("__dict__"):
|
if not super().__getattribute__("__dict__"):
|
||||||
return super().__getattribute__(name)
|
return super().__getattribute__(name)
|
||||||
|
|
||||||
# Check if tracking is enabled.
|
|
||||||
if super().__getattribute__("track_vars"):
|
|
||||||
# Get the non-computed vars.
|
|
||||||
all_vars = {
|
|
||||||
**super().__getattribute__("vars"),
|
|
||||||
**super().__getattribute__("backend_vars"),
|
|
||||||
}
|
|
||||||
# Add the var to the tracked vars.
|
|
||||||
if name in all_vars:
|
|
||||||
super().__getattribute__("tracked_vars").add(name)
|
|
||||||
|
|
||||||
inherited_vars = {
|
inherited_vars = {
|
||||||
**super().__getattribute__("inherited_vars"),
|
**super().__getattribute__("inherited_vars"),
|
||||||
**super().__getattribute__("inherited_backend_vars"),
|
**super().__getattribute__("inherited_backend_vars"),
|
||||||
@ -676,55 +647,58 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Return the state update.
|
# Return the state update.
|
||||||
return StateUpdate(delta=delta, events=events)
|
return StateUpdate(delta=delta, events=events)
|
||||||
|
|
||||||
def _dirty_computed_vars(
|
def _mark_dirty_computed_vars(self) -> None:
|
||||||
self, from_vars: Optional[Set[str]] = None, check: bool = False
|
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||||
) -> Set[str]:
|
dirty_vars = self.dirty_vars
|
||||||
"""Get ComputedVars that need to be recomputed based on dirty_vars.
|
while dirty_vars:
|
||||||
|
calc_vars, dirty_vars = dirty_vars, set()
|
||||||
|
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||||
|
self.dirty_vars.add(cvar)
|
||||||
|
dirty_vars.add(cvar)
|
||||||
|
actual_var = self.computed_vars.get(cvar)
|
||||||
|
if actual_var:
|
||||||
|
actual_var.mark_dirty(instance=self)
|
||||||
|
|
||||||
|
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
||||||
|
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
||||||
check: Whether to perform the check.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Set of computed vars to include in the delta.
|
Set of computed vars to include in the delta.
|
||||||
"""
|
"""
|
||||||
# If checking is disabled, return all computed vars.
|
|
||||||
if not check:
|
|
||||||
return set(self.computed_vars)
|
|
||||||
|
|
||||||
# Return only the computed vars that depend on the dirty vars.
|
|
||||||
return set(
|
return set(
|
||||||
cvar
|
cvar
|
||||||
for dirty_var in from_vars or self.dirty_vars
|
for dirty_var in from_vars or self.dirty_vars
|
||||||
for cvar in self.computed_vars
|
for cvar in self.computed_var_dependencies[dirty_var]
|
||||||
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_delta(self, check: bool = False) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
|
|
||||||
Args:
|
|
||||||
check: Whether to check for dirty computed vars.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The delta for the state.
|
The delta for the state.
|
||||||
"""
|
"""
|
||||||
delta = {}
|
delta = {}
|
||||||
|
|
||||||
# Return the dirty vars, as well as computed vars depending on dirty vars.
|
|
||||||
subdelta = {
|
|
||||||
prop: getattr(self, prop)
|
|
||||||
for prop in self.dirty_vars | self._dirty_computed_vars(check=check)
|
|
||||||
if not types.is_backend_variable(prop)
|
|
||||||
}
|
|
||||||
if len(subdelta) > 0:
|
|
||||||
delta[self.get_full_name()] = subdelta
|
|
||||||
|
|
||||||
# Recursively find the substate deltas.
|
# Recursively find the substate deltas.
|
||||||
substates = self.substates
|
substates = self.substates
|
||||||
for substate in self.dirty_substates:
|
for substate in self.dirty_substates:
|
||||||
delta.update(substates[substate].get_delta())
|
delta.update(substates[substate].get_delta())
|
||||||
|
|
||||||
|
# Return the dirty vars and dependent computed vars
|
||||||
|
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||||
|
self._dirty_computed_vars()
|
||||||
|
)
|
||||||
|
subdelta = {
|
||||||
|
prop: getattr(self, prop)
|
||||||
|
for prop in delta_vars
|
||||||
|
if not types.is_backend_variable(prop)
|
||||||
|
}
|
||||||
|
if len(subdelta) > 0:
|
||||||
|
delta[self.get_full_name()] = subdelta
|
||||||
|
|
||||||
# Format the delta.
|
# Format the delta.
|
||||||
delta = format.format_state(delta)
|
delta = format.format_state(delta)
|
||||||
|
|
||||||
@ -737,6 +711,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
self.parent_state.dirty_substates.add(self.get_name())
|
self.parent_state.dirty_substates.add(self.get_name())
|
||||||
self.parent_state.mark_dirty()
|
self.parent_state.mark_dirty()
|
||||||
|
|
||||||
|
# have to mark computed vars dirty to allow access to newly computed
|
||||||
|
# values within the same ComputedVar function
|
||||||
|
self._mark_dirty_computed_vars()
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
"""Reset the dirty vars."""
|
"""Reset the dirty vars."""
|
||||||
# Recursively clean the substates.
|
# Recursively clean the substates.
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
"""Define a state var."""
|
"""Define a state var."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import dis
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from types import FunctionType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -12,9 +15,11 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
_GenericAlias, # type: ignore
|
_GenericAlias, # type: ignore
|
||||||
|
cast,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -801,6 +806,84 @@ class ComputedVar(property, Var):
|
|||||||
assert self.fget is not None, "Var must have a getter."
|
assert self.fget is not None, "Var must have a getter."
|
||||||
return self.fget.__name__
|
return self.fget.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_attr(self) -> str:
|
||||||
|
"""Get the attribute used to cache the value on the instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An attribute name.
|
||||||
|
"""
|
||||||
|
return f"__cached_{self.name}"
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
"""Get the ComputedVar value.
|
||||||
|
|
||||||
|
If the value is already cached on the instance, return the cached value.
|
||||||
|
|
||||||
|
If this ComputedVar doesn't know what type of object it is attached to, then save
|
||||||
|
a reference as self.__objclass__.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: the instance of the class accessing this computed var.
|
||||||
|
owner: the class that this descriptor is attached to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the var for the given instance.
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "__objclass__"):
|
||||||
|
self.__objclass__ = owner
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
return super().__get__(instance, owner)
|
||||||
|
|
||||||
|
# handle caching
|
||||||
|
if not hasattr(instance, self.cache_attr):
|
||||||
|
setattr(instance, self.cache_attr, super().__get__(instance, owner))
|
||||||
|
return getattr(instance, self.cache_attr)
|
||||||
|
|
||||||
|
def deps(self, obj: Optional[FunctionType] = None) -> Set[str]:
|
||||||
|
"""Determine var dependencies of this ComputedVar.
|
||||||
|
|
||||||
|
Save references to attributes accessed on "self". Recursively called
|
||||||
|
when the function makes a method call on "self".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: the object to disassemble (defaults to the fget function).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of variable names accessed by the given obj.
|
||||||
|
"""
|
||||||
|
d = set()
|
||||||
|
if obj is None:
|
||||||
|
if self.fget is not None:
|
||||||
|
obj = cast(FunctionType, self.fget)
|
||||||
|
else:
|
||||||
|
return set()
|
||||||
|
if not obj.__code__.co_varnames:
|
||||||
|
# cannot reference self if method takes no args
|
||||||
|
return set()
|
||||||
|
self_name = obj.__code__.co_varnames[0]
|
||||||
|
self_is_top_of_stack = False
|
||||||
|
for instruction in dis.get_instructions(obj):
|
||||||
|
if instruction.opname == "LOAD_FAST" and instruction.argval == self_name:
|
||||||
|
self_is_top_of_stack = True
|
||||||
|
continue
|
||||||
|
if self_is_top_of_stack and instruction.opname == "LOAD_ATTR":
|
||||||
|
d.add(instruction.argval)
|
||||||
|
elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD":
|
||||||
|
d.update(self.deps(obj=getattr(self.__objclass__, instruction.argval)))
|
||||||
|
self_is_top_of_stack = False
|
||||||
|
return d
|
||||||
|
|
||||||
|
def mark_dirty(self, instance) -> None:
|
||||||
|
"""Mark this ComputedVar as dirty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: the state instance that needs to recompute the value.
|
||||||
|
"""
|
||||||
|
with contextlib.suppress(AttributeError):
|
||||||
|
delattr(instance, self.cache_attr)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type_(self):
|
def type_(self):
|
||||||
"""Get the type of the var.
|
"""Get the type of the var.
|
||||||
|
@ -485,11 +485,11 @@ def test_set_dirty_var(test_state):
|
|||||||
|
|
||||||
# Setting a var should mark it as dirty.
|
# Setting a var should mark it as dirty.
|
||||||
test_state.num1 = 1
|
test_state.num1 = 1
|
||||||
assert test_state.dirty_vars == {"num1"}
|
assert test_state.dirty_vars == {"num1", "sum"}
|
||||||
|
|
||||||
# Setting another var should mark it as dirty.
|
# Setting another var should mark it as dirty.
|
||||||
test_state.num2 = 2
|
test_state.num2 = 2
|
||||||
assert test_state.dirty_vars == {"num1", "num2"}
|
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||||
|
|
||||||
# Cleaning the state should remove all dirty vars.
|
# Cleaning the state should remove all dirty vars.
|
||||||
test_state.clean()
|
test_state.clean()
|
||||||
@ -578,7 +578,7 @@ async def test_process_event_simple(test_state):
|
|||||||
assert test_state.num1 == 69
|
assert test_state.num1 == 69
|
||||||
|
|
||||||
# The delta should contain the changes, including computed vars.
|
# The delta should contain the changes, including computed vars.
|
||||||
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
|
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
|
||||||
assert update.events == []
|
assert update.events == []
|
||||||
|
|
||||||
|
|
||||||
@ -601,7 +601,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
assert child_state.value == "HI"
|
assert child_state.value == "HI"
|
||||||
assert child_state.count == 24
|
assert child_state.count == 24
|
||||||
assert update.delta == {
|
assert update.delta == {
|
||||||
"test_state": {"sum": 3.14, "upper": ""},
|
|
||||||
"test_state.child_state": {"value": "HI", "count": 24},
|
"test_state.child_state": {"value": "HI", "count": 24},
|
||||||
}
|
}
|
||||||
test_state.clean()
|
test_state.clean()
|
||||||
@ -616,7 +615,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
update = await test_state._process(event)
|
update = await test_state._process(event)
|
||||||
assert grandchild_state.value2 == "new"
|
assert grandchild_state.value2 == "new"
|
||||||
assert update.delta == {
|
assert update.delta == {
|
||||||
"test_state": {"sum": 3.14, "upper": ""},
|
|
||||||
"test_state.child_state.grandchild_state": {"value2": "new"},
|
"test_state.child_state.grandchild_state": {"value2": "new"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -791,7 +789,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state):
|
|||||||
interdependent_state: A state with varying Var dependencies.
|
interdependent_state: A state with varying Var dependencies.
|
||||||
"""
|
"""
|
||||||
interdependent_state.x = 5
|
interdependent_state.x = 5
|
||||||
assert interdependent_state.get_delta(check=True) == {
|
assert interdependent_state.get_delta() == {
|
||||||
interdependent_state.get_full_name(): {"x": 5},
|
interdependent_state.get_full_name(): {"x": 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -806,7 +804,7 @@ def test_dirty_computed_var_from_var(interdependent_state):
|
|||||||
interdependent_state: A state with varying Var dependencies.
|
interdependent_state: A state with varying Var dependencies.
|
||||||
"""
|
"""
|
||||||
interdependent_state.v1 = 1
|
interdependent_state.v1 = 1
|
||||||
assert interdependent_state.get_delta(check=True) == {
|
assert interdependent_state.get_delta() == {
|
||||||
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -818,7 +816,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
|
|||||||
interdependent_state: A state with varying Var dependencies.
|
interdependent_state: A state with varying Var dependencies.
|
||||||
"""
|
"""
|
||||||
interdependent_state._v2 = 2
|
interdependent_state._v2 = 2
|
||||||
assert interdependent_state.get_delta(check=True) == {
|
assert interdependent_state.get_delta() == {
|
||||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -860,6 +858,7 @@ def test_conditional_computed_vars():
|
|||||||
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
|
||||||
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
|
||||||
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
|
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
|
||||||
|
assert ms.computed_vars["rendered_var"].deps() == {"flag", "t1", "t2"}
|
||||||
|
|
||||||
|
|
||||||
def test_event_handlers_convert_to_fns(test_state, child_state):
|
def test_event_handlers_convert_to_fns(test_state, child_state):
|
||||||
@ -896,3 +895,29 @@ def test_event_handlers_call_other_handlers():
|
|||||||
ms = MainState()
|
ms = MainState()
|
||||||
ms.set_v2(1)
|
ms.set_v2(1)
|
||||||
assert ms.v == 1
|
assert ms.v == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_computed_var_cached():
|
||||||
|
"""Test that a ComputedVar doesn't recalculate when accessed."""
|
||||||
|
comp_v_calls = 0
|
||||||
|
|
||||||
|
class ComputedState(State):
|
||||||
|
v: int = 0
|
||||||
|
|
||||||
|
@ComputedVar
|
||||||
|
def comp_v(self) -> int:
|
||||||
|
nonlocal comp_v_calls
|
||||||
|
comp_v_calls += 1
|
||||||
|
return self.v
|
||||||
|
|
||||||
|
cs = ComputedState()
|
||||||
|
assert cs.dict()["v"] == 0
|
||||||
|
assert comp_v_calls == 1
|
||||||
|
assert cs.dict()["comp_v"] == 0
|
||||||
|
assert comp_v_calls == 1
|
||||||
|
assert cs.comp_v == 0
|
||||||
|
assert comp_v_calls == 1
|
||||||
|
cs.v = 1
|
||||||
|
assert comp_v_calls == 1
|
||||||
|
assert cs.comp_v == 1
|
||||||
|
assert comp_v_calls == 2
|
||||||
|
Loading…
Reference in New Issue
Block a user