Only update ComputedVar when dependent vars change (#840)
This commit is contained in:
parent
3be43bdab1
commit
b4755b8123
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
@ -14,6 +15,7 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
@ -51,6 +53,9 @@ class State(Base, ABC):
|
||||
# Backend vars inherited
|
||||
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
|
||||
# The event handlers.
|
||||
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
||||
|
||||
@ -171,6 +176,7 @@ class State(Base, ABC):
|
||||
**cls.base_vars,
|
||||
**cls.computed_vars,
|
||||
}
|
||||
cls.computed_var_dependencies = {}
|
||||
|
||||
# Setup the base vars at the class level.
|
||||
for prop in cls.base_vars.values():
|
||||
@ -472,12 +478,28 @@ class State(Base, ABC):
|
||||
|
||||
If the var is inherited, get the var from the parent state.
|
||||
|
||||
If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the var.
|
||||
|
||||
Returns:
|
||||
The value of the var.
|
||||
"""
|
||||
vars = {
|
||||
**super().__getattribute__("vars"),
|
||||
**super().__getattribute__("backend_vars"),
|
||||
}
|
||||
if name in vars:
|
||||
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
|
||||
if parent_frame is not None:
|
||||
computed_vars = super().__getattribute__("computed_vars")
|
||||
requesting_attribute_name = parent_frame_locals.get("name")
|
||||
if requesting_attribute_name in computed_vars:
|
||||
# Keep track of any ComputedVar that depends on this Var
|
||||
super().__getattribute__("computed_var_dependencies").setdefault(
|
||||
name, set()
|
||||
).add(requesting_attribute_name)
|
||||
inherited_vars = {
|
||||
**super().__getattribute__("inherited_vars"),
|
||||
**super().__getattribute__("inherited_backend_vars"),
|
||||
@ -505,6 +527,7 @@ class State(Base, ABC):
|
||||
|
||||
if types.is_backend_variable(name):
|
||||
self.backend_vars.__setitem__(name, value)
|
||||
self.dirty_vars.add(name)
|
||||
self.mark_dirty()
|
||||
return
|
||||
|
||||
@ -622,6 +645,28 @@ class State(Base, ABC):
|
||||
# Return the state update.
|
||||
return StateUpdate(delta=delta, events=events)
|
||||
|
||||
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
||||
"""Get ComputedVars that need to be recomputed based on dirty_vars.
|
||||
|
||||
Args:
|
||||
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
||||
|
||||
Returns:
|
||||
Set of computed vars to include in the delta.
|
||||
"""
|
||||
dirty_computed_vars = set(
|
||||
cvar
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self.computed_vars
|
||||
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
||||
)
|
||||
if dirty_computed_vars:
|
||||
# recursive call to catch computed vars that depend on computed vars
|
||||
return dirty_computed_vars | self._dirty_computed_vars(
|
||||
from_vars=dirty_computed_vars
|
||||
)
|
||||
return dirty_computed_vars
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
|
||||
@ -630,10 +675,11 @@ class State(Base, ABC):
|
||||
"""
|
||||
delta = {}
|
||||
|
||||
# Return the dirty vars, as well as all computed vars.
|
||||
# Return the dirty vars, as well as computed vars depending on dirty vars.
|
||||
subdelta = {
|
||||
prop: getattr(self, prop)
|
||||
for prop in self.dirty_vars | self.computed_vars.keys()
|
||||
for prop in self.dirty_vars | self._dirty_computed_vars()
|
||||
if not types.is_backend_variable(prop)
|
||||
}
|
||||
if len(subdelta) > 0:
|
||||
delta[self.get_full_name()] = subdelta
|
||||
@ -803,3 +849,24 @@ def _convert_mutable_datatypes(
|
||||
field_value, reassign_field=reassign_field, field_name=field_name
|
||||
)
|
||||
return field_value
|
||||
|
||||
|
||||
def _get_previous_recursive_frame_info() -> (
|
||||
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
|
||||
):
|
||||
"""Find the previous frame of the same function that calls this helper.
|
||||
|
||||
For example, if this function is called from `State.__getattribute__`
|
||||
(parent frame), then the returned frame will be the next earliest call
|
||||
of the same function.
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_info, local_vars)
|
||||
|
||||
If no previous recursive frame is found up the stack, the frame info will be None.
|
||||
"""
|
||||
_this_frame, parent_frame, *prev_frames = inspect.stack()
|
||||
for frame in prev_frames:
|
||||
if frame.frame.f_code == parent_frame.frame.f_code:
|
||||
return frame, frame.frame.f_locals
|
||||
return None, {}
|
||||
|
@ -582,7 +582,7 @@ async def test_process_event_simple(test_state):
|
||||
assert test_state.num1 == 69
|
||||
|
||||
# The delta should contain the changes, including computed vars.
|
||||
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
|
||||
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
|
||||
assert update.events == []
|
||||
|
||||
|
||||
@ -606,7 +606,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
assert child_state.count == 24
|
||||
assert update.delta == {
|
||||
"test_state.child_state": {"value": "HI", "count": 24},
|
||||
"test_state": {"sum": 3.14, "upper": ""},
|
||||
}
|
||||
test_state.clean()
|
||||
|
||||
@ -621,7 +620,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
assert grandchild_state.value2 == "new"
|
||||
assert update.delta == {
|
||||
"test_state.child_state.grandchild_state": {"value2": "new"},
|
||||
"test_state": {"sum": 3.14, "upper": ""},
|
||||
}
|
||||
|
||||
|
||||
@ -724,3 +722,98 @@ def test_add_var(test_state):
|
||||
test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
|
||||
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
||||
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
||||
|
||||
|
||||
class InterdependentState(State):
|
||||
"""A state with 3 vars and 3 computed vars.
|
||||
|
||||
x: a variable that no computed var depends on
|
||||
v1: a varable that one computed var directly depeneds on
|
||||
_v2: a backend variable that one computed var directly depends on
|
||||
|
||||
v1x2: a computed var that depends on v1
|
||||
v2x2: a computed var that depends on backend var _v2
|
||||
v1x2x2: a computed var that depends on computed var v1x2
|
||||
"""
|
||||
|
||||
x: int = 0
|
||||
v1: int = 0
|
||||
_v2: int = 1
|
||||
|
||||
@ComputedVar
|
||||
def v1x2(self) -> int:
|
||||
"""depends on var v1.
|
||||
|
||||
Returns:
|
||||
Var v1 multiplied by 2
|
||||
"""
|
||||
return self.v1 * 2
|
||||
|
||||
@ComputedVar
|
||||
def v2x2(self) -> int:
|
||||
"""depends on backend var _v2.
|
||||
|
||||
Returns:
|
||||
backend var _v2 multiplied by 2
|
||||
"""
|
||||
return self._v2 * 2
|
||||
|
||||
@ComputedVar
|
||||
def v1x2x2(self) -> int:
|
||||
"""depends on ComputedVar v1x2.
|
||||
|
||||
Returns:
|
||||
ComputedVar v1x2 multiplied by 2
|
||||
"""
|
||||
return self.v1x2 * 2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def interdependent_state() -> State:
|
||||
"""A state with varying dependency between vars.
|
||||
|
||||
Returns:
|
||||
instance of InterdependentState
|
||||
"""
|
||||
s = InterdependentState()
|
||||
s.dict() # prime initial relationships by accessing all ComputedVars
|
||||
return s
|
||||
|
||||
|
||||
def test_not_dirty_computed_var_from_var(interdependent_state):
|
||||
"""Set Var that no ComputedVar depends on, expect no recalculation.
|
||||
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state.x = 5
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"x": 5},
|
||||
}
|
||||
|
||||
|
||||
def test_dirty_computed_var_from_var(interdependent_state):
|
||||
"""Set Var that ComputedVar depends on, expect recalculation.
|
||||
|
||||
The other ComputedVar depends on the changed ComputedVar and should also be
|
||||
recalculated. No other ComputedVars should be recalculated.
|
||||
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state.v1 = 1
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||
}
|
||||
|
||||
|
||||
def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||
"""Set backend var that ComputedVar depends on, expect recalculation.
|
||||
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state._v2 = 2
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user