Only update ComputedVar when dependent vars change (#840)

This commit is contained in:
Masen Furer 2023-04-23 17:48:44 -07:00 committed by GitHub
parent 3be43bdab1
commit b4755b8123
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 165 additions and 5 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import functools
import inspect
import traceback
from abc import ABC
from typing import (
@ -14,6 +15,7 @@ from typing import (
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
@ -51,6 +53,9 @@ class State(Base, ABC):
# Backend vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
@ -171,6 +176,7 @@ class State(Base, ABC):
**cls.base_vars,
**cls.computed_vars,
}
cls.computed_var_dependencies = {}
# Setup the base vars at the class level.
for prop in cls.base_vars.values():
@ -472,12 +478,28 @@ class State(Base, ABC):
If the var is inherited, get the var from the parent state.
If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
Args:
name: The name of the var.
Returns:
The value of the var.
"""
vars = {
**super().__getattribute__("vars"),
**super().__getattribute__("backend_vars"),
}
if name in vars:
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
if parent_frame is not None:
computed_vars = super().__getattribute__("computed_vars")
requesting_attribute_name = parent_frame_locals.get("name")
if requesting_attribute_name in computed_vars:
# Keep track of any ComputedVar that depends on this Var
super().__getattribute__("computed_var_dependencies").setdefault(
name, set()
).add(requesting_attribute_name)
inherited_vars = {
**super().__getattribute__("inherited_vars"),
**super().__getattribute__("inherited_backend_vars"),
@ -505,6 +527,7 @@ class State(Base, ABC):
if types.is_backend_variable(name):
self.backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self.mark_dirty()
return
@ -622,6 +645,28 @@ class State(Base, ABC):
# Return the state update.
return StateUpdate(delta=delta, events=events)
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
"""Get ComputedVars that need to be recomputed based on dirty_vars.
Args:
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
Returns:
Set of computed vars to include in the delta.
"""
dirty_computed_vars = set(
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_vars
if cvar in self.computed_var_dependencies.get(dirty_var, set())
)
if dirty_computed_vars:
# recursive call to catch computed vars that depend on computed vars
return dirty_computed_vars | self._dirty_computed_vars(
from_vars=dirty_computed_vars
)
return dirty_computed_vars
def get_delta(self) -> Delta:
"""Get the delta for the state.
@ -630,10 +675,11 @@ class State(Base, ABC):
"""
delta = {}
# Return the dirty vars, as well as all computed vars.
# Return the dirty vars, as well as computed vars depending on dirty vars.
subdelta = {
prop: getattr(self, prop)
for prop in self.dirty_vars | self.computed_vars.keys()
for prop in self.dirty_vars | self._dirty_computed_vars()
if not types.is_backend_variable(prop)
}
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
@ -803,3 +849,24 @@ def _convert_mutable_datatypes(
field_value, reassign_field=reassign_field, field_name=field_name
)
return field_value
def _get_previous_recursive_frame_info() -> (
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
):
"""Find the previous frame of the same function that calls this helper.
For example, if this function is called from `State.__getattribute__`
(parent frame), then the returned frame will be the next earliest call
of the same function.
Returns:
Tuple of (frame_info, local_vars)
If no previous recursive frame is found up the stack, the frame info will be None.
"""
_this_frame, parent_frame, *prev_frames = inspect.stack()
for frame in prev_frames:
if frame.frame.f_code == parent_frame.frame.f_code:
return frame, frame.frame.f_locals
return None, {}

View File

@ -582,7 +582,7 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 69
# The delta should contain the changes, including computed vars.
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
assert update.events == []
@ -606,7 +606,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert child_state.count == 24
assert update.delta == {
"test_state.child_state": {"value": "HI", "count": 24},
"test_state": {"sum": 3.14, "upper": ""},
}
test_state.clean()
@ -621,7 +620,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert grandchild_state.value2 == "new"
assert update.delta == {
"test_state.child_state.grandchild_state": {"value2": "new"},
"test_state": {"sum": 3.14, "upper": ""},
}
@ -724,3 +722,98 @@ def test_add_var(test_state):
test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
class InterdependentState(State):
"""A state with 3 vars and 3 computed vars.
x: a variable that no computed var depends on
v1: a varable that one computed var directly depeneds on
_v2: a backend variable that one computed var directly depends on
v1x2: a computed var that depends on v1
v2x2: a computed var that depends on backend var _v2
v1x2x2: a computed var that depends on computed var v1x2
"""
x: int = 0
v1: int = 0
_v2: int = 1
@ComputedVar
def v1x2(self) -> int:
"""depends on var v1.
Returns:
Var v1 multiplied by 2
"""
return self.v1 * 2
@ComputedVar
def v2x2(self) -> int:
"""depends on backend var _v2.
Returns:
backend var _v2 multiplied by 2
"""
return self._v2 * 2
@ComputedVar
def v1x2x2(self) -> int:
"""depends on ComputedVar v1x2.
Returns:
ComputedVar v1x2 multiplied by 2
"""
return self.v1x2 * 2
@pytest.fixture
def interdependent_state() -> State:
"""A state with varying dependency between vars.
Returns:
instance of InterdependentState
"""
s = InterdependentState()
s.dict() # prime initial relationships by accessing all ComputedVars
return s
def test_not_dirty_computed_var_from_var(interdependent_state):
"""Set Var that no ComputedVar depends on, expect no recalculation.
Args:
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state.x = 5
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"x": 5},
}
def test_dirty_computed_var_from_var(interdependent_state):
"""Set Var that ComputedVar depends on, expect recalculation.
The other ComputedVar depends on the changed ComputedVar and should also be
recalculated. No other ComputedVars should be recalculated.
Args:
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state.v1 = 1
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
}
def test_dirty_computed_var_from_backend_var(interdependent_state):
"""Set backend var that ComputedVar depends on, expect recalculation.
Args:
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state._v2 = 2
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4},
}