Temp remove computed var dependency checks (#972)

This commit is contained in:
Nikhil Rao 2023-05-08 18:00:03 -07:00 committed by GitHub
parent 3d3c974768
commit dc2dff9323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 46 deletions

View File

@ -660,15 +660,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
dirty_vars = self.dirty_vars
while dirty_vars:
calc_vars, dirty_vars = dirty_vars, set()
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
self.dirty_vars.add(cvar)
dirty_vars.add(cvar)
actual_var = self.computed_vars.get(cvar)
if actual_var:
actual_var.mark_dirty(instance=self)
# Mark all ComputedVars as dirty.
for cvar in self.computed_vars.values():
cvar.mark_dirty(instance=self)
# TODO: Uncomment the actual implementation below.
# dirty_vars = self.dirty_vars
# while dirty_vars:
# calc_vars, dirty_vars = dirty_vars, set()
# for cvar in self._dirty_computed_vars(from_vars=calc_vars):
# self.dirty_vars.add(cvar)
# dirty_vars.add(cvar)
# actual_var = self.computed_vars.get(cvar)
# if actual_var:
# actual_var.mark_dirty(instance=self)
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
@ -679,11 +684,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns:
Set of computed vars to include in the delta.
"""
return set(
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_var_dependencies[dirty_var]
)
return set(self.computed_vars)
# TODO: Uncomment the actual implementation below.
# return set(
# cvar
# for dirty_var in from_vars or self.dirty_vars
# for cvar in self.computed_var_dependencies[dirty_var]
# )
def get_delta(self) -> Delta:
"""Get the delta for the state.
@ -699,6 +706,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
delta.update(substates[substate].get_delta())
# Return the dirty vars and dependent computed vars
self._mark_dirty_computed_vars()
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
self._dirty_computed_vars()
)

View File

@ -484,11 +484,13 @@ def test_set_dirty_var(test_state):
# Setting a var should mark it as dirty.
test_state.num1 = 1
assert test_state.dirty_vars == {"num1", "sum"}
# assert test_state.dirty_vars == {"num1", "sum"}
assert test_state.dirty_vars == {"num1"}
# Setting another var should mark it as dirty.
test_state.num2 = 2
assert test_state.dirty_vars == {"num1", "num2", "sum"}
# assert test_state.dirty_vars == {"num1", "num2", "sum"}
assert test_state.dirty_vars == {"num1", "num2"}
# Cleaning the state should remove all dirty vars.
test_state.clean()
@ -577,7 +579,8 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 69
# The delta should contain the changes, including computed vars.
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
# assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}}
assert update.events == []
@ -600,6 +603,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state": {"value": "HI", "count": 24},
}
test_state.clean()
@ -614,6 +618,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
update = await test_state._process(event)
assert grandchild_state.value2 == "new"
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state.grandchild_state": {"value2": "new"},
}
@ -781,43 +786,43 @@ def interdependent_state() -> State:
return s
def test_not_dirty_computed_var_from_var(interdependent_state):
"""Set Var that no ComputedVar depends on, expect no recalculation.
# def test_not_dirty_computed_var_from_var(interdependent_state):
# """Set Var that no ComputedVar depends on, expect no recalculation.
Args:
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state.x = 5
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"x": 5},
}
# Args:
# interdependent_state: A state with varying Var dependencies.
# """
# interdependent_state.x = 5
# assert interdependent_state.get_delta() == {
# interdependent_state.get_full_name(): {"x": 5},
# }
def test_dirty_computed_var_from_var(interdependent_state):
"""Set Var that ComputedVar depends on, expect recalculation.
# def test_dirty_computed_var_from_var(interdependent_state):
# """Set Var that ComputedVar depends on, expect recalculation.
The other ComputedVar depends on the changed ComputedVar and should also be
recalculated. No other ComputedVars should be recalculated.
# The other ComputedVar depends on the changed ComputedVar and should also be
# recalculated. No other ComputedVars should be recalculated.
Args:
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state.v1 = 1
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
}
# Args:
# interdependent_state: A state with varying Var dependencies.
# """
# interdependent_state.v1 = 1
# assert interdependent_state.get_delta() == {
# interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
# }
def test_dirty_computed_var_from_backend_var(interdependent_state):
"""Set backend var that ComputedVar depends on, expect recalculation.
# def test_dirty_computed_var_from_backend_var(interdependent_state):
# """Set backend var that ComputedVar depends on, expect recalculation.
Args:
interdependent_state: A state with varying Var dependencies.
"""
interdependent_state._v2 = 2
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4},
}
# Args:
# interdependent_state: A state with varying Var dependencies.
# """
# interdependent_state._v2 = 2
# assert interdependent_state.get_delta() == {
# interdependent_state.get_full_name(): {"v2x2": 4},
# }
def test_per_state_backend_var(interdependent_state):