track substate ComputedVar that depends on parent Var (#852)
This commit is contained in:
parent
b5bc7e5d8c
commit
fabaa7be1c
@ -75,6 +75,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Mapping of var name to set of computed variables that depend on it
|
# Mapping of var name to set of computed variables that depend on it
|
||||||
computed_var_dependencies: Dict[str, Set[str]] = {}
|
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
|
# Mapping of var name to set of substates that depend on it
|
||||||
|
substate_var_dependencies: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
# Per-instance copy of backend variable values
|
# Per-instance copy of backend variable values
|
||||||
_backend_vars: Dict[str, Any] = {}
|
_backend_vars: Dict[str, Any] = {}
|
||||||
|
|
||||||
@ -89,6 +92,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
kwargs["parent_state"] = parent_state
|
kwargs["parent_state"] = parent_state
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# initialize per-instance var dependency tracking
|
||||||
|
self.computed_var_dependencies = defaultdict(set)
|
||||||
|
self.substate_var_dependencies = defaultdict(set)
|
||||||
|
|
||||||
# Setup the substates.
|
# Setup the substates.
|
||||||
for substate in self.get_substates():
|
for substate in self.get_substates():
|
||||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||||
@ -101,11 +108,23 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
setattr(self, name, fn)
|
setattr(self, name, fn)
|
||||||
|
|
||||||
# Initialize computed vars dependencies.
|
# Initialize computed vars dependencies.
|
||||||
self.computed_var_dependencies = defaultdict(set)
|
inherited_vars = set(self.inherited_vars).union(
|
||||||
|
set(self.inherited_backend_vars),
|
||||||
|
)
|
||||||
for cvar_name, cvar in self.computed_vars.items():
|
for cvar_name, cvar in self.computed_vars.items():
|
||||||
# Add the dependencies.
|
# Add the dependencies.
|
||||||
for var in cvar.deps(objclass=type(self)):
|
for var in cvar.deps(objclass=type(self)):
|
||||||
self.computed_var_dependencies[var].add(cvar_name)
|
self.computed_var_dependencies[var].add(cvar_name)
|
||||||
|
if var in inherited_vars:
|
||||||
|
# track that this substate depends on its parent for this var
|
||||||
|
state_name = self.get_name()
|
||||||
|
parent_state = self.parent_state
|
||||||
|
while parent_state is not None and var in parent_state.vars:
|
||||||
|
parent_state.substate_var_dependencies[var].add(state_name)
|
||||||
|
state_name, parent_state = (
|
||||||
|
parent_state.get_name(),
|
||||||
|
parent_state.parent_state,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize the mutable fields.
|
# Initialize the mutable fields.
|
||||||
self._init_mutable_fields()
|
self._init_mutable_fields()
|
||||||
@ -226,6 +245,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"dirty_substates",
|
"dirty_substates",
|
||||||
"router_data",
|
"router_data",
|
||||||
"computed_var_dependencies",
|
"computed_var_dependencies",
|
||||||
|
"substate_var_dependencies",
|
||||||
"_backend_vars",
|
"_backend_vars",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,6 +397,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
cls.base_vars.update({name: var})
|
cls.base_vars.update({name: var})
|
||||||
cls.vars.update({name: var})
|
cls.vars.update({name: var})
|
||||||
|
|
||||||
|
# let substates know about the new variable
|
||||||
|
for substate_class in cls.__subclasses__():
|
||||||
|
substate_class.vars.setdefault(name, var)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _set_var(cls, prop: BaseVar):
|
def _set_var(cls, prop: BaseVar):
|
||||||
"""Set the var as a class member.
|
"""Set the var as a class member.
|
||||||
@ -698,11 +722,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"""
|
"""
|
||||||
delta = {}
|
delta = {}
|
||||||
|
|
||||||
# Recursively find the substate deltas.
|
|
||||||
substates = self.substates
|
|
||||||
for substate in self.dirty_substates:
|
|
||||||
delta.update(substates[substate].get_delta())
|
|
||||||
|
|
||||||
# Return the dirty vars and dependent computed vars
|
# Return the dirty vars and dependent computed vars
|
||||||
self._mark_dirty_computed_vars()
|
self._mark_dirty_computed_vars()
|
||||||
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||||
@ -716,6 +735,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if len(subdelta) > 0:
|
if len(subdelta) > 0:
|
||||||
delta[self.get_full_name()] = subdelta
|
delta[self.get_full_name()] = subdelta
|
||||||
|
|
||||||
|
# Recursively find the substate deltas.
|
||||||
|
substates = self.substates
|
||||||
|
for substate in self.dirty_substates:
|
||||||
|
delta.update(substates[substate].get_delta())
|
||||||
|
|
||||||
# Format the delta.
|
# Format the delta.
|
||||||
delta = format.format_state(delta)
|
delta = format.format_state(delta)
|
||||||
|
|
||||||
@ -724,7 +748,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
def mark_dirty(self):
|
def mark_dirty(self):
|
||||||
"""Mark the substate and all parent states as dirty."""
|
"""Mark the substate and all parent states as dirty."""
|
||||||
if self.parent_state is not None:
|
state_name = self.get_name()
|
||||||
|
if (
|
||||||
|
self.parent_state is not None
|
||||||
|
and state_name not in self.parent_state.dirty_substates
|
||||||
|
):
|
||||||
self.parent_state.dirty_substates.add(self.get_name())
|
self.parent_state.dirty_substates.add(self.get_name())
|
||||||
self.parent_state.mark_dirty()
|
self.parent_state.mark_dirty()
|
||||||
|
|
||||||
@ -732,6 +760,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# values within the same ComputedVar function
|
# values within the same ComputedVar function
|
||||||
self._mark_dirty_computed_vars()
|
self._mark_dirty_computed_vars()
|
||||||
|
|
||||||
|
# Propagate dirty var / computed var status into substates
|
||||||
|
substates = self.substates
|
||||||
|
for var in self.dirty_vars:
|
||||||
|
for substate_name in self.substate_var_dependencies[var]:
|
||||||
|
self.dirty_substates.add(substate_name)
|
||||||
|
substate = substates[substate_name]
|
||||||
|
substate.dirty_vars.add(var)
|
||||||
|
substate.mark_dirty()
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
"""Reset the dirty vars."""
|
"""Reset the dirty vars."""
|
||||||
# Recursively clean the substates.
|
# Recursively clean the substates.
|
||||||
|
Loading…
Reference in New Issue
Block a user