Only update ComputedVar when dependent vars change (#840)
This commit is contained in:
parent
3be43bdab1
commit
b4755b8123
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -14,6 +15,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -51,6 +53,9 @@ class State(Base, ABC):
|
|||||||
# Backend vars inherited
|
# Backend vars inherited
|
||||||
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# Mapping of var name to set of computed variables that depend on it
|
||||||
|
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||||
|
|
||||||
# The event handlers.
|
# The event handlers.
|
||||||
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
||||||
|
|
||||||
@ -171,6 +176,7 @@ class State(Base, ABC):
|
|||||||
**cls.base_vars,
|
**cls.base_vars,
|
||||||
**cls.computed_vars,
|
**cls.computed_vars,
|
||||||
}
|
}
|
||||||
|
cls.computed_var_dependencies = {}
|
||||||
|
|
||||||
# Setup the base vars at the class level.
|
# Setup the base vars at the class level.
|
||||||
for prop in cls.base_vars.values():
|
for prop in cls.base_vars.values():
|
||||||
@ -472,12 +478,28 @@ class State(Base, ABC):
|
|||||||
|
|
||||||
If the var is inherited, get the var from the parent state.
|
If the var is inherited, get the var from the parent state.
|
||||||
|
|
||||||
|
If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the var.
|
name: The name of the var.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The value of the var.
|
The value of the var.
|
||||||
"""
|
"""
|
||||||
|
vars = {
|
||||||
|
**super().__getattribute__("vars"),
|
||||||
|
**super().__getattribute__("backend_vars"),
|
||||||
|
}
|
||||||
|
if name in vars:
|
||||||
|
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
|
||||||
|
if parent_frame is not None:
|
||||||
|
computed_vars = super().__getattribute__("computed_vars")
|
||||||
|
requesting_attribute_name = parent_frame_locals.get("name")
|
||||||
|
if requesting_attribute_name in computed_vars:
|
||||||
|
# Keep track of any ComputedVar that depends on this Var
|
||||||
|
super().__getattribute__("computed_var_dependencies").setdefault(
|
||||||
|
name, set()
|
||||||
|
).add(requesting_attribute_name)
|
||||||
inherited_vars = {
|
inherited_vars = {
|
||||||
**super().__getattribute__("inherited_vars"),
|
**super().__getattribute__("inherited_vars"),
|
||||||
**super().__getattribute__("inherited_backend_vars"),
|
**super().__getattribute__("inherited_backend_vars"),
|
||||||
@ -505,6 +527,7 @@ class State(Base, ABC):
|
|||||||
|
|
||||||
if types.is_backend_variable(name):
|
if types.is_backend_variable(name):
|
||||||
self.backend_vars.__setitem__(name, value)
|
self.backend_vars.__setitem__(name, value)
|
||||||
|
self.dirty_vars.add(name)
|
||||||
self.mark_dirty()
|
self.mark_dirty()
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -622,6 +645,28 @@ class State(Base, ABC):
|
|||||||
# Return the state update.
|
# Return the state update.
|
||||||
return StateUpdate(delta=delta, events=events)
|
return StateUpdate(delta=delta, events=events)
|
||||||
|
|
||||||
|
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
||||||
|
"""Get ComputedVars that need to be recomputed based on dirty_vars.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of computed vars to include in the delta.
|
||||||
|
"""
|
||||||
|
dirty_computed_vars = set(
|
||||||
|
cvar
|
||||||
|
for dirty_var in from_vars or self.dirty_vars
|
||||||
|
for cvar in self.computed_vars
|
||||||
|
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
||||||
|
)
|
||||||
|
if dirty_computed_vars:
|
||||||
|
# recursive call to catch computed vars that depend on computed vars
|
||||||
|
return dirty_computed_vars | self._dirty_computed_vars(
|
||||||
|
from_vars=dirty_computed_vars
|
||||||
|
)
|
||||||
|
return dirty_computed_vars
|
||||||
|
|
||||||
def get_delta(self) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
|
|
||||||
@ -630,10 +675,11 @@ class State(Base, ABC):
|
|||||||
"""
|
"""
|
||||||
delta = {}
|
delta = {}
|
||||||
|
|
||||||
# Return the dirty vars, as well as all computed vars.
|
# Return the dirty vars, as well as computed vars depending on dirty vars.
|
||||||
subdelta = {
|
subdelta = {
|
||||||
prop: getattr(self, prop)
|
prop: getattr(self, prop)
|
||||||
for prop in self.dirty_vars | self.computed_vars.keys()
|
for prop in self.dirty_vars | self._dirty_computed_vars()
|
||||||
|
if not types.is_backend_variable(prop)
|
||||||
}
|
}
|
||||||
if len(subdelta) > 0:
|
if len(subdelta) > 0:
|
||||||
delta[self.get_full_name()] = subdelta
|
delta[self.get_full_name()] = subdelta
|
||||||
@ -803,3 +849,24 @@ def _convert_mutable_datatypes(
|
|||||||
field_value, reassign_field=reassign_field, field_name=field_name
|
field_value, reassign_field=reassign_field, field_name=field_name
|
||||||
)
|
)
|
||||||
return field_value
|
return field_value
|
||||||
|
|
||||||
|
|
||||||
|
def _get_previous_recursive_frame_info() -> (
|
||||||
|
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
|
||||||
|
):
|
||||||
|
"""Find the previous frame of the same function that calls this helper.
|
||||||
|
|
||||||
|
For example, if this function is called from `State.__getattribute__`
|
||||||
|
(parent frame), then the returned frame will be the next earliest call
|
||||||
|
of the same function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (frame_info, local_vars)
|
||||||
|
|
||||||
|
If no previous recursive frame is found up the stack, the frame info will be None.
|
||||||
|
"""
|
||||||
|
_this_frame, parent_frame, *prev_frames = inspect.stack()
|
||||||
|
for frame in prev_frames:
|
||||||
|
if frame.frame.f_code == parent_frame.frame.f_code:
|
||||||
|
return frame, frame.frame.f_locals
|
||||||
|
return None, {}
|
||||||
|
@ -582,7 +582,7 @@ async def test_process_event_simple(test_state):
|
|||||||
assert test_state.num1 == 69
|
assert test_state.num1 == 69
|
||||||
|
|
||||||
# The delta should contain the changes, including computed vars.
|
# The delta should contain the changes, including computed vars.
|
||||||
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
|
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
|
||||||
assert update.events == []
|
assert update.events == []
|
||||||
|
|
||||||
|
|
||||||
@ -606,7 +606,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
assert child_state.count == 24
|
assert child_state.count == 24
|
||||||
assert update.delta == {
|
assert update.delta == {
|
||||||
"test_state.child_state": {"value": "HI", "count": 24},
|
"test_state.child_state": {"value": "HI", "count": 24},
|
||||||
"test_state": {"sum": 3.14, "upper": ""},
|
|
||||||
}
|
}
|
||||||
test_state.clean()
|
test_state.clean()
|
||||||
|
|
||||||
@ -621,7 +620,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
assert grandchild_state.value2 == "new"
|
assert grandchild_state.value2 == "new"
|
||||||
assert update.delta == {
|
assert update.delta == {
|
||||||
"test_state.child_state.grandchild_state": {"value2": "new"},
|
"test_state.child_state.grandchild_state": {"value2": "new"},
|
||||||
"test_state": {"sum": 3.14, "upper": ""},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -724,3 +722,98 @@ def test_add_var(test_state):
|
|||||||
test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
|
test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
|
||||||
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
||||||
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
||||||
|
|
||||||
|
|
||||||
|
class InterdependentState(State):
|
||||||
|
"""A state with 3 vars and 3 computed vars.
|
||||||
|
|
||||||
|
x: a variable that no computed var depends on
|
||||||
|
v1: a varable that one computed var directly depeneds on
|
||||||
|
_v2: a backend variable that one computed var directly depends on
|
||||||
|
|
||||||
|
v1x2: a computed var that depends on v1
|
||||||
|
v2x2: a computed var that depends on backend var _v2
|
||||||
|
v1x2x2: a computed var that depends on computed var v1x2
|
||||||
|
"""
|
||||||
|
|
||||||
|
x: int = 0
|
||||||
|
v1: int = 0
|
||||||
|
_v2: int = 1
|
||||||
|
|
||||||
|
@ComputedVar
|
||||||
|
def v1x2(self) -> int:
|
||||||
|
"""depends on var v1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Var v1 multiplied by 2
|
||||||
|
"""
|
||||||
|
return self.v1 * 2
|
||||||
|
|
||||||
|
@ComputedVar
|
||||||
|
def v2x2(self) -> int:
|
||||||
|
"""depends on backend var _v2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
backend var _v2 multiplied by 2
|
||||||
|
"""
|
||||||
|
return self._v2 * 2
|
||||||
|
|
||||||
|
@ComputedVar
|
||||||
|
def v1x2x2(self) -> int:
|
||||||
|
"""depends on ComputedVar v1x2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ComputedVar v1x2 multiplied by 2
|
||||||
|
"""
|
||||||
|
return self.v1x2 * 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def interdependent_state() -> State:
|
||||||
|
"""A state with varying dependency between vars.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
instance of InterdependentState
|
||||||
|
"""
|
||||||
|
s = InterdependentState()
|
||||||
|
s.dict() # prime initial relationships by accessing all ComputedVars
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def test_not_dirty_computed_var_from_var(interdependent_state):
|
||||||
|
"""Set Var that no ComputedVar depends on, expect no recalculation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interdependent_state: A state with varying Var dependencies.
|
||||||
|
"""
|
||||||
|
interdependent_state.x = 5
|
||||||
|
assert interdependent_state.get_delta() == {
|
||||||
|
interdependent_state.get_full_name(): {"x": 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_dirty_computed_var_from_var(interdependent_state):
|
||||||
|
"""Set Var that ComputedVar depends on, expect recalculation.
|
||||||
|
|
||||||
|
The other ComputedVar depends on the changed ComputedVar and should also be
|
||||||
|
recalculated. No other ComputedVars should be recalculated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interdependent_state: A state with varying Var dependencies.
|
||||||
|
"""
|
||||||
|
interdependent_state.v1 = 1
|
||||||
|
assert interdependent_state.get_delta() == {
|
||||||
|
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||||
|
"""Set backend var that ComputedVar depends on, expect recalculation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interdependent_state: A state with varying Var dependencies.
|
||||||
|
"""
|
||||||
|
interdependent_state._v2 = 2
|
||||||
|
assert interdependent_state.get_delta() == {
|
||||||
|
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user