track substate ComputedVar that depends on parent Var (#852)
This commit is contained in:
parent
b5bc7e5d8c
commit
fabaa7be1c
@ -75,6 +75,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||
|
||||
# Mapping of var name to set of substates that depend on it
|
||||
substate_var_dependencies: Dict[str, Set[str]] = {}
|
||||
|
||||
# Per-instance copy of backend variable values
|
||||
_backend_vars: Dict[str, Any] = {}
|
||||
|
||||
@ -89,6 +92,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
kwargs["parent_state"] = parent_state
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# initialize per-instance var dependency tracking
|
||||
self.computed_var_dependencies = defaultdict(set)
|
||||
self.substate_var_dependencies = defaultdict(set)
|
||||
|
||||
# Setup the substates.
|
||||
for substate in self.get_substates():
|
||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||
@ -101,11 +108,23 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
setattr(self, name, fn)
|
||||
|
||||
# Initialize computed vars dependencies.
|
||||
self.computed_var_dependencies = defaultdict(set)
|
||||
inherited_vars = set(self.inherited_vars).union(
|
||||
set(self.inherited_backend_vars),
|
||||
)
|
||||
for cvar_name, cvar in self.computed_vars.items():
|
||||
# Add the dependencies.
|
||||
for var in cvar.deps(objclass=type(self)):
|
||||
self.computed_var_dependencies[var].add(cvar_name)
|
||||
if var in inherited_vars:
|
||||
# track that this substate depends on its parent for this var
|
||||
state_name = self.get_name()
|
||||
parent_state = self.parent_state
|
||||
while parent_state is not None and var in parent_state.vars:
|
||||
parent_state.substate_var_dependencies[var].add(state_name)
|
||||
state_name, parent_state = (
|
||||
parent_state.get_name(),
|
||||
parent_state.parent_state,
|
||||
)
|
||||
|
||||
# Initialize the mutable fields.
|
||||
self._init_mutable_fields()
|
||||
@ -226,6 +245,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
"computed_var_dependencies",
|
||||
"substate_var_dependencies",
|
||||
"_backend_vars",
|
||||
}
|
||||
|
||||
@ -377,6 +397,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.base_vars.update({name: var})
|
||||
cls.vars.update({name: var})
|
||||
|
||||
# let substates know about the new variable
|
||||
for substate_class in cls.__subclasses__():
|
||||
substate_class.vars.setdefault(name, var)
|
||||
|
||||
@classmethod
|
||||
def _set_var(cls, prop: BaseVar):
|
||||
"""Set the var as a class member.
|
||||
@ -698,11 +722,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
delta = {}
|
||||
|
||||
# Recursively find the substate deltas.
|
||||
substates = self.substates
|
||||
for substate in self.dirty_substates:
|
||||
delta.update(substates[substate].get_delta())
|
||||
|
||||
# Return the dirty vars and dependent computed vars
|
||||
self._mark_dirty_computed_vars()
|
||||
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||
@ -716,6 +735,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if len(subdelta) > 0:
|
||||
delta[self.get_full_name()] = subdelta
|
||||
|
||||
# Recursively find the substate deltas.
|
||||
substates = self.substates
|
||||
for substate in self.dirty_substates:
|
||||
delta.update(substates[substate].get_delta())
|
||||
|
||||
# Format the delta.
|
||||
delta = format.format_state(delta)
|
||||
|
||||
@ -724,7 +748,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
def mark_dirty(self):
|
||||
"""Mark the substate and all parent states as dirty."""
|
||||
if self.parent_state is not None:
|
||||
state_name = self.get_name()
|
||||
if (
|
||||
self.parent_state is not None
|
||||
and state_name not in self.parent_state.dirty_substates
|
||||
):
|
||||
self.parent_state.dirty_substates.add(self.get_name())
|
||||
self.parent_state.mark_dirty()
|
||||
|
||||
@ -732,6 +760,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# values within the same ComputedVar function
|
||||
self._mark_dirty_computed_vars()
|
||||
|
||||
# Propagate dirty var / computed var status into substates
|
||||
substates = self.substates
|
||||
for var in self.dirty_vars:
|
||||
for substate_name in self.substate_var_dependencies[var]:
|
||||
self.dirty_substates.add(substate_name)
|
||||
substate = substates[substate_name]
|
||||
substate.dirty_vars.add(var)
|
||||
substate.mark_dirty()
|
||||
|
||||
def clean(self):
|
||||
"""Reset the dirty vars."""
|
||||
# Recursively clean the substates.
|
||||
|
Loading…
Reference in New Issue
Block a user