Cache ComputedVar (#917)

This commit is contained in:
Masen Furer 2023-05-04 00:11:39 -07:00 committed by GitHub
parent bad2363506
commit c344a5c0d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 154 additions and 68 deletions

View File

@ -74,12 +74,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: Dict[str, Set[str]] = {}
# Whether to track accessed vars.
track_vars: bool = False
# The current set of accessed vars during tracking.
tracked_vars: Set[str] = set()
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
"""Initialize the state.
@ -102,22 +96,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
setattr(self, name, fn)
# Initialize the mutable fields.
self._init_mutable_fields()
# Initialize computed vars dependencies.
self.computed_var_dependencies = defaultdict(set)
for cvar in self.computed_vars:
self.tracked_vars = set()
# Enable tracking and get the computed var.
self.track_vars = True
self.__getattribute__(cvar)
self.track_vars = False
for cvar_name, cvar in self.computed_vars.items():
# Add the dependencies.
for var in self.tracked_vars:
self.computed_var_dependencies[var].add(cvar)
for var in cvar.deps():
self.computed_var_dependencies[var].add(cvar_name)
# Initialize the mutable fields.
self._init_mutable_fields()
def _init_mutable_fields(self):
"""Initialize mutable fields.
@ -199,7 +186,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
**cls.base_vars,
**cls.computed_vars,
}
cls.computed_var_dependencies = {}
cls.event_handlers = {}
# Setup the base vars at the class level.
@ -233,8 +219,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
"dirty_substates",
"router_data",
"computed_var_dependencies",
"track_vars",
"tracked_vars",
}
@classmethod
@ -508,8 +492,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
If the var is inherited, get the var from the parent state.
If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
Args:
name: The name of the var.
@ -520,17 +502,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if not super().__getattribute__("__dict__"):
return super().__getattribute__(name)
# Check if tracking is enabled.
if super().__getattribute__("track_vars"):
# Get the non-computed vars.
all_vars = {
**super().__getattribute__("vars"),
**super().__getattribute__("backend_vars"),
}
# Add the var to the tracked vars.
if name in all_vars:
super().__getattribute__("tracked_vars").add(name)
inherited_vars = {
**super().__getattribute__("inherited_vars"),
**super().__getattribute__("inherited_backend_vars"),
@ -676,55 +647,58 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Return the state update.
return StateUpdate(delta=delta, events=events)
def _dirty_computed_vars(
self, from_vars: Optional[Set[str]] = None, check: bool = False
) -> Set[str]:
"""Get ComputedVars that need to be recomputed based on dirty_vars.
def _mark_dirty_computed_vars(self) -> None:
    """Mark ComputedVars that need to be recalculated based on dirty_vars.

    Starting from the current dirty vars, repeatedly look up the computed
    vars that depend on them, add those to ``dirty_vars`` and drop their
    cached values, then continue from the newly marked computed vars so
    that computed vars depending on other computed vars are reached too.

    Each computed var is handled at most once per call, so the walk ends
    even if the dependency map contains a cycle.
    """
    # Computed vars already marked during this call (cycle / diamond guard).
    seen = set()
    frontier = self.dirty_vars
    while frontier:
        # _dirty_computed_vars returns a fresh set, so adding to
        # self.dirty_vars below never mutates a set being iterated.
        dependents = self._dirty_computed_vars(from_vars=frontier) - seen
        seen |= dependents
        frontier = dependents
        for cvar_name in dependents:
            self.dirty_vars.add(cvar_name)
            actual_var = self.computed_vars.get(cvar_name)
            if actual_var:
                # Invalidate the cached value for this instance.
                actual_var.mark_dirty(instance=self)
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Args:
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
check: Whether to perform the check.
Returns:
Set of computed vars to include in the delta.
"""
# If checking is disabled, return all computed vars.
if not check:
return set(self.computed_vars)
# Return only the computed vars that depend on the dirty vars.
return set(
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_vars
if cvar in self.computed_var_dependencies.get(dirty_var, set())
for cvar in self.computed_var_dependencies[dirty_var]
)
def get_delta(self, check: bool = False) -> Delta:
def get_delta(self) -> Delta:
"""Get the delta for the state.
Args:
check: Whether to check for dirty computed vars.
Returns:
The delta for the state.
"""
delta = {}
# Return the dirty vars, as well as computed vars depending on dirty vars.
subdelta = {
prop: getattr(self, prop)
for prop in self.dirty_vars | self._dirty_computed_vars(check=check)
if not types.is_backend_variable(prop)
}
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates:
delta.update(substates[substate].get_delta())
# Return the dirty vars and dependent computed vars
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
self._dirty_computed_vars()
)
subdelta = {
prop: getattr(self, prop)
for prop in delta_vars
if not types.is_backend_variable(prop)
}
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
# Format the delta.
delta = format.format_state(delta)
@ -737,6 +711,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state.mark_dirty()
# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
def clean(self):
"""Reset the dirty vars."""
# Recursively clean the substates.

View File

@ -1,10 +1,13 @@
"""Define a state var."""
from __future__ import annotations
import contextlib
import dis
import json
import random
import string
from abc import ABC
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
@ -12,9 +15,11 @@ from typing import (
Dict,
List,
Optional,
Set,
Type,
Union,
_GenericAlias, # type: ignore
cast,
get_type_hints,
)
@ -801,6 +806,84 @@ class ComputedVar(property, Var):
assert self.fget is not None, "Var must have a getter."
return self.fget.__name__
@property
def cache_attr(self) -> str:
    """Name of the instance attribute that stores this var's cached value.

    Returns:
        The var name prefixed with ``__cached_``.
    """
    attr_name = f"__cached_{self.name}"
    return attr_name
def __get__(self, instance, owner):
    """Return the value of this ComputedVar, computing it at most once.

    The first time the descriptor is used, the owning class is remembered
    as ``self.__objclass__``. Access through the class (``instance`` is
    None) is delegated to the parent ``__get__``. Access through an
    instance returns the value stored under ``cache_attr``; when nothing
    is stored yet, the parent ``__get__`` result is stored there first.

    Args:
        instance: the instance of the class accessing this computed var.
        owner: the class that this descriptor is attached to.

    Returns:
        The value of the var for the given instance.
    """
    # Record the owning class once.
    if not hasattr(self, "__objclass__"):
        self.__objclass__ = owner

    # Class-level access: no instance to cache on.
    if instance is None:
        return super().__get__(instance, owner)

    attr = self.cache_attr
    if hasattr(instance, attr):
        return getattr(instance, attr)

    # Cache miss: compute the value, store it, then read it back.
    value = super().__get__(instance, owner)
    setattr(instance, attr, value)
    return getattr(instance, attr)
def deps(self, obj: Optional[FunctionType] = None) -> Set[str]:
    """Determine var dependencies of this ComputedVar.

    Disassemble the getter and record every attribute loaded directly
    from the function's first positional argument ("self"). When such an
    attribute resolves, on the owning class, to something that carries
    its own ``__code__`` (a helper method), that code is disassembled as
    well instead of recording the method name, so vars read through
    helper methods are included.

    Args:
        obj: the object to disassemble (defaults to the fget function).

    Returns:
        A set of variable names accessed by the given obj.
    """
    if obj is None:
        if self.fget is None:
            return set()
        obj = cast(FunctionType, self.fget)

    # The owning class is recorded by __get__ and may not be known yet.
    owner = getattr(self, "__objclass__", None)

    found: Set[str] = set()
    visited = set()  # code objects already scanned; stops recursive helpers
    pending = [obj]
    while pending:
        code = getattr(pending.pop(), "__code__", None)
        if code is None or code in visited:
            # Nothing to disassemble (e.g. a builtin) or already handled.
            continue
        visited.add(code)
        if not code.co_argcount:
            # co_varnames also lists locals, so use co_argcount to decide
            # whether there is a "self" argument at all.
            continue
        self_name = code.co_varnames[0]
        self_is_top_of_stack = False
        for instruction in dis.get_instructions(code):
            opname = instruction.opname
            if opname.startswith("LOAD_FAST"):
                # Covers LOAD_FAST and its specialised variants; combined
                # loads report a tuple whose last name ends on top.
                loaded = instruction.argval
                if isinstance(loaded, tuple):
                    loaded = loaded[-1]
                self_is_top_of_stack = loaded == self_name
                continue
            if self_is_top_of_stack and opname in ("LOAD_ATTR", "LOAD_METHOD"):
                attr_name = instruction.argval
                target = (
                    getattr(owner, attr_name, None) if owner is not None else None
                )
                if getattr(target, "__code__", None) is not None:
                    # Helper method on the owning class: scan it too.
                    pending.append(target)
                else:
                    # Plain attribute access on self.
                    found.add(attr_name)
            self_is_top_of_stack = False
    return found
def mark_dirty(self, instance) -> None:
    """Drop the cached value of this ComputedVar on the given instance.

    Args:
        instance: the state instance that needs to recompute the value.
    """
    try:
        delattr(instance, self.cache_attr)
    except AttributeError:
        # Nothing was cached for this instance; there is nothing to drop.
        pass
@property
def type_(self):
"""Get the type of the var.

View File

@ -485,11 +485,11 @@ def test_set_dirty_var(test_state):
# Setting a var should mark it as dirty.
test_state.num1 = 1
assert test_state.dirty_vars == {"num1"}
assert test_state.dirty_vars == {"num1", "sum"}
# Setting another var should mark it as dirty.
test_state.num2 = 2
assert test_state.dirty_vars == {"num1", "num2"}
assert test_state.dirty_vars == {"num1", "num2", "sum"}
# Cleaning the state should remove all dirty vars.
test_state.clean()
@ -578,7 +578,7 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 69
# The delta should contain the changes, including computed vars.
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
assert update.events == []
@ -601,7 +601,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state": {"value": "HI", "count": 24},
}
test_state.clean()
@ -616,7 +615,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
update = await test_state._process(event)
assert grandchild_state.value2 == "new"
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state.grandchild_state": {"value2": "new"},
}
@ -791,7 +789,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state):
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state.x = 5
assert interdependent_state.get_delta(check=True) == {
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"x": 5},
}
@ -806,7 +804,7 @@ def test_dirty_computed_var_from_var(interdependent_state):
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state.v1 = 1
assert interdependent_state.get_delta(check=True) == {
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
}
@ -818,7 +816,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state._v2 = 2
assert interdependent_state.get_delta(check=True) == {
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4},
}
@ -860,6 +858,7 @@ def test_conditional_computed_vars():
assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
assert ms.computed_vars["rendered_var"].deps() == {"flag", "t1", "t2"}
def test_event_handlers_convert_to_fns(test_state, child_state):
@ -896,3 +895,29 @@ def test_event_handlers_call_other_handlers():
ms = MainState()
ms.set_v2(1)
assert ms.v == 1
def test_computed_var_cached():
    """Check that a ComputedVar getter only reruns after a dependency is set."""
    # One entry is appended per execution of the getter.
    getter_runs = []

    class ComputedState(State):
        v: int = 0

        @ComputedVar
        def comp_v(self) -> int:
            getter_runs.append(None)
            return self.v

    state = ComputedState()

    # Reading through dict() or attribute access reuses the cached value.
    assert state.dict()["v"] == 0
    assert len(getter_runs) == 1
    assert state.dict()["comp_v"] == 0
    assert len(getter_runs) == 1
    assert state.comp_v == 0
    assert len(getter_runs) == 1

    # Setting the dependency does not recompute by itself...
    state.v = 1
    assert len(getter_runs) == 1

    # ...but the next access does.
    assert state.comp_v == 1
    assert len(getter_runs) == 2