From fabaa7be1c6f07bd759f7800f963489fe8c627e5 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 9 May 2023 16:12:24 -0700 Subject: [PATCH] track substate ComputedVar that depends on parent Var (#852) --- pynecone/state.py | 51 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/pynecone/state.py b/pynecone/state.py index 70a5010d1..62fca226a 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -75,6 +75,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Mapping of var name to set of computed variables that depend on it computed_var_dependencies: Dict[str, Set[str]] = {} + # Mapping of var name to set of substates that depend on it + substate_var_dependencies: Dict[str, Set[str]] = {} + # Per-instance copy of backend variable values _backend_vars: Dict[str, Any] = {} @@ -89,6 +92,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) + # initialize per-instance var dependency tracking + self.computed_var_dependencies = defaultdict(set) + self.substate_var_dependencies = defaultdict(set) + # Setup the substates. for substate in self.get_substates(): self.substates[substate.get_name()] = substate(parent_state=self) @@ -101,11 +108,23 @@ class State(Base, ABC, extra=pydantic.Extra.allow): setattr(self, name, fn) # Initialize computed vars dependencies. - self.computed_var_dependencies = defaultdict(set) + inherited_vars = set(self.inherited_vars).union( + set(self.inherited_backend_vars), + ) for cvar_name, cvar in self.computed_vars.items(): # Add the dependencies. for var in cvar.deps(objclass=type(self)): self.computed_var_dependencies[var].add(cvar_name) + if var in inherited_vars: + # track that this substate depends on its parent for this var + state_name = self.get_name() + parent_state = self.parent_state + while parent_state is not None and var in parent_state.vars: + parent_state.substate_var_dependencies[var].add(state_name) + state_name, parent_state = ( + parent_state.get_name(), + parent_state.parent_state, + ) # Initialize the mutable fields. self._init_mutable_fields() @@ -226,6 +245,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): "dirty_substates", "router_data", "computed_var_dependencies", + "substate_var_dependencies", "_backend_vars", } @@ -377,6 +397,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): cls.base_vars.update({name: var}) cls.vars.update({name: var}) + # let substates know about the new variable + for substate_class in cls.__subclasses__(): + substate_class.vars.setdefault(name, var) + @classmethod def _set_var(cls, prop: BaseVar): """Set the var as a class member. @@ -698,11 +722,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): """ delta = {} - # Recursively find the substate deltas. - substates = self.substates - for substate in self.dirty_substates: - delta.update(substates[substate].get_delta()) - # Return the dirty vars and dependent computed vars self._mark_dirty_computed_vars() delta_vars = self.dirty_vars.intersection(self.base_vars).union( @@ -716,6 +735,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if len(subdelta) > 0: delta[self.get_full_name()] = subdelta + # Recursively find the substate deltas. + substates = self.substates + for substate in self.dirty_substates: + delta.update(substates[substate].get_delta()) + # Format the delta. delta = format.format_state(delta) @@ -724,7 +748,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow): def mark_dirty(self): """Mark the substate and all parent states as dirty.""" - if self.parent_state is not None: + state_name = self.get_name() + if ( + self.parent_state is not None + and state_name not in self.parent_state.dirty_substates + ): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state.mark_dirty() @@ -732,6 +760,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # values within the same ComputedVar function self._mark_dirty_computed_vars() + # Propagate dirty var / computed var status into substates + substates = self.substates + for var in self.dirty_vars: + for substate_name in self.substate_var_dependencies[var]: + self.dirty_substates.add(substate_name) + substate = substates[substate_name] + substate.dirty_vars.add(var) + substate.mark_dirty() + def clean(self): """Reset the dirty vars.""" # Recursively clean the substates.