[REF-1035] Track ComputedVar dependency per class (#2067)

This commit is contained in:
Masen Furer 2023-11-27 18:17:53 -08:00 committed by GitHub
parent 626357ed87
commit ee87e62efa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 67 deletions

View File

@ -143,6 +143,15 @@ class RouterData(Base):
self.page = PageData(router_data)
RESERVED_BACKEND_VAR_NAMES = {
"_backend_vars",
"_computed_var_dependencies",
"_substate_var_dependencies",
"_always_dirty_computed_vars",
"_always_dirty_substates",
}
class State(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
@ -167,6 +176,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
# Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# Mapping of var name to set of substates that depend on it
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# Set of vars which always need to be recomputed
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
# Set of substates which always need to be recomputed
_always_dirty_substates: ClassVar[Set[str]] = set()
# The parent state.
parent_state: Optional[State] = None
@ -182,12 +203,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state
router_data: Dict[str, Any] = {}
# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: Dict[str, Set[str]] = {}
# Mapping of var name to set of substates that depend on it
substate_var_dependencies: Dict[str, Set[str]] = {}
# Per-instance copy of backend variable values
_backend_vars: Dict[str, Any] = {}
@ -211,10 +226,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)
# initialize per-instance var dependency tracking
self.computed_var_dependencies = defaultdict(set)
self.substate_var_dependencies = defaultdict(set)
# Setup the substates.
for substate in self.get_substates():
substate_name = substate.get_name()
@ -227,25 +238,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Convert the event handlers to functions.
self._init_event_handlers()
# Initialize computed vars dependencies.
inherited_vars = set(self.inherited_vars).union(
set(self.inherited_backend_vars),
)
for cvar_name, cvar in self.computed_vars.items():
# Add the dependencies.
for var in cvar._deps(objclass=type(self)):
self.computed_var_dependencies[var].add(cvar_name)
if var in inherited_vars:
# track that this substate depends on its parent for this var
state_name = self.get_name()
parent_state = self.parent_state
while parent_state is not None and var in parent_state.vars:
parent_state.substate_var_dependencies[var].add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.parent_state,
)
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars)
@ -347,6 +339,60 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.event_handlers[name] = handler
setattr(cls, name, handler)
cls._init_var_dependency_dicts()
@classmethod
def _init_var_dependency_dicts(cls):
"""Initialize the var dependency tracking dicts.
Allows the state to know which vars each ComputedVar depends on and
whether a ComputedVar depends on a var in its parent state.
Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
# Initialize per-class var dependency tracking.
cls._computed_var_dependencies = defaultdict(set)
cls._substate_var_dependencies = defaultdict(set)
inherited_vars = set(cls.inherited_vars).union(
set(cls.inherited_backend_vars),
)
for cvar_name, cvar in cls.computed_vars.items():
# Add the dependencies.
for var in cvar._deps(objclass=cls):
cls._computed_var_dependencies[var].add(cvar_name)
if var in inherited_vars:
# track that this substate depends on its parent for this var
state_name = cls.get_name()
parent_state = cls.get_parent_state()
while parent_state is not None and var in parent_state.vars:
parent_state._substate_var_dependencies[var].add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.get_parent_state(),
)
# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = set(
cvar_name
for cvar_name, cvar in cls.computed_vars.items()
if not cvar._cache
)
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
cls._always_dirty_substates = set()
if cls._always_dirty_computed_vars:
# Tell parent classes that this substate has always dirty computed vars
state_name = cls.get_name()
parent_state = cls.get_parent_state()
while parent_state is not None:
parent_state._always_dirty_substates.add(state_name)
state_name, parent_state = (
parent_state.get_name(),
parent_state.get_parent_state(),
)
@classmethod
def _check_overridden_methods(cls):
"""Check for shadow methods and raise error if any.
@ -377,16 +423,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The vars to skip when serializing.
"""
return set(cls.inherited_vars) | {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
"computed_var_dependencies",
"substate_var_dependencies",
"_backend_vars",
}
return (
set(cls.inherited_vars)
| {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
}
| RESERVED_BACKEND_VAR_NAMES
)
@classmethod
@functools.lru_cache()
@ -540,6 +587,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
for substate_class in cls.__subclasses__():
substate_class.vars.setdefault(name, var)
# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()
@classmethod
def _set_var(cls, prop: BaseVar):
"""Set the var as a class member.
@ -749,6 +799,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore
setattr(cls, param, func)
# Reinitialize dependency tracking dicts.
cls._init_var_dependency_dicts()
def __getattribute__(self, name: str) -> Any:
"""Get the state var.
@ -804,7 +857,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
setattr(self.parent_state, name, value)
return
if types.is_backend_variable(name) and name != "_backend_vars":
if types.is_backend_variable(name) and name not in RESERVED_BACKEND_VAR_NAMES:
self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self._mark_dirty()
@ -814,7 +867,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
super().__setattr__(name, value)
# Add the var to the dirty list.
if name in self.vars or name in self.computed_var_dependencies:
if name in self.vars or name in self._computed_var_dependencies:
self.dirty_vars.add(name)
self._mark_dirty()
@ -1056,18 +1109,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
final=True,
)
def _always_dirty_computed_vars(self) -> set[str]:
"""The set of ComputedVars that always need to be recalculated.
Returns:
Set of all ComputedVar in this state where cache=False
"""
return set(
cvar_name
for cvar_name, cvar in self.computed_vars.items()
if not cvar._cache
)
def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
dirty_vars = self.dirty_vars
@ -1092,7 +1133,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
return set(
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_var_dependencies[dirty_var]
for cvar in self._computed_var_dependencies[dirty_var]
)
def get_delta(self) -> Delta:
@ -1104,7 +1145,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
delta = {}
# Apply dirty variables down into substates
self.dirty_vars.update(self._always_dirty_computed_vars())
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()
# Return the dirty vars for this instance, any cached/dependent computed vars,
@ -1112,7 +1153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)
subdelta = {
@ -1125,7 +1166,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates:
for substate in self.dirty_substates.union(self._always_dirty_substates):
delta.update(substates[substate].get_delta())
# Format the delta.
@ -1151,7 +1192,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Propagate dirty var / computed var status into substates
substates = self.substates
for var in self.dirty_vars:
for substate_name in self.substate_var_dependencies[var]:
for substate_name in self._substate_var_dependencies[var]:
self.dirty_substates.add(substate_name)
substate = substates[substate_name]
substate.dirty_vars.add(var)
@ -1195,7 +1236,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if include_computed:
# Apply dirty variables down into substates to allow never-cached ComputedVar to
# trigger recalculation of dependent vars
self.dirty_vars.update(self._always_dirty_computed_vars())
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()
base_vars = {

View File

@ -257,7 +257,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
constants.ROUTER
}
assert constants.ROUTER in app.state().computed_var_dependencies
assert constants.ROUTER in app.state()._computed_var_dependencies
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@ -917,7 +917,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
constants.ROUTER
}
assert constants.ROUTER in app.state().computed_var_dependencies
assert constants.ROUTER in app.state()._computed_var_dependencies
sid = "mock_sid"
client_ip = "127.0.0.1"

View File

@ -1215,7 +1215,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
assert "cached_x_side_effect" in s.computed_var_dependencies["x"]
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
@ -1283,11 +1283,11 @@ def test_computed_var_dependencies():
return [z in self._z for z in range(5)]
cs = ComputedState()
assert cs.computed_var_dependencies["v"] == {"comp_v"}
assert cs.computed_var_dependencies["w"] == {"comp_w"}
assert cs.computed_var_dependencies["x"] == {"comp_x"}
assert cs.computed_var_dependencies["y"] == {"comp_y"}
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
assert cs._computed_var_dependencies["v"] == {"comp_v"}
assert cs._computed_var_dependencies["w"] == {"comp_w"}
assert cs._computed_var_dependencies["x"] == {"comp_x"}
assert cs._computed_var_dependencies["y"] == {"comp_y"}
assert cs._computed_var_dependencies["_z"] == {"comp_z"}
def test_backend_method():