[REF-1035] Track ComputedVar dependency per class (#2067)
This commit is contained in:
parent
626357ed87
commit
ee87e62efa
159
reflex/state.py
159
reflex/state.py
@ -143,6 +143,15 @@ class RouterData(Base):
|
||||
self.page = PageData(router_data)
|
||||
|
||||
|
||||
RESERVED_BACKEND_VAR_NAMES = {
|
||||
"_backend_vars",
|
||||
"_computed_var_dependencies",
|
||||
"_substate_var_dependencies",
|
||||
"_always_dirty_computed_vars",
|
||||
"_always_dirty_substates",
|
||||
}
|
||||
|
||||
|
||||
class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""The state of the app."""
|
||||
|
||||
@ -167,6 +176,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The event handlers.
|
||||
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
||||
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
|
||||
# Mapping of var name to set of substates that depend on it
|
||||
_substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
|
||||
# Set of vars which always need to be recomputed
|
||||
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
|
||||
|
||||
# Set of substates which always need to be recomputed
|
||||
_always_dirty_substates: ClassVar[Set[str]] = set()
|
||||
|
||||
# The parent state.
|
||||
parent_state: Optional[State] = None
|
||||
|
||||
@ -182,12 +203,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The routing path that triggered the state
|
||||
router_data: Dict[str, Any] = {}
|
||||
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||
|
||||
# Mapping of var name to set of substates that depend on it
|
||||
substate_var_dependencies: Dict[str, Set[str]] = {}
|
||||
|
||||
# Per-instance copy of backend variable values
|
||||
_backend_vars: Dict[str, Any] = {}
|
||||
|
||||
@ -211,10 +226,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
kwargs["parent_state"] = parent_state
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# initialize per-instance var dependency tracking
|
||||
self.computed_var_dependencies = defaultdict(set)
|
||||
self.substate_var_dependencies = defaultdict(set)
|
||||
|
||||
# Setup the substates.
|
||||
for substate in self.get_substates():
|
||||
substate_name = substate.get_name()
|
||||
@ -227,25 +238,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Convert the event handlers to functions.
|
||||
self._init_event_handlers()
|
||||
|
||||
# Initialize computed vars dependencies.
|
||||
inherited_vars = set(self.inherited_vars).union(
|
||||
set(self.inherited_backend_vars),
|
||||
)
|
||||
for cvar_name, cvar in self.computed_vars.items():
|
||||
# Add the dependencies.
|
||||
for var in cvar._deps(objclass=type(self)):
|
||||
self.computed_var_dependencies[var].add(cvar_name)
|
||||
if var in inherited_vars:
|
||||
# track that this substate depends on its parent for this var
|
||||
state_name = self.get_name()
|
||||
parent_state = self.parent_state
|
||||
while parent_state is not None and var in parent_state.vars:
|
||||
parent_state.substate_var_dependencies[var].add(state_name)
|
||||
state_name, parent_state = (
|
||||
parent_state.get_name(),
|
||||
parent_state.parent_state,
|
||||
)
|
||||
|
||||
# Create a fresh copy of the backend variables for this instance
|
||||
self._backend_vars = copy.deepcopy(self.backend_vars)
|
||||
|
||||
@ -347,6 +339,60 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.event_handlers[name] = handler
|
||||
setattr(cls, name, handler)
|
||||
|
||||
cls._init_var_dependency_dicts()
|
||||
|
||||
@classmethod
|
||||
def _init_var_dependency_dicts(cls):
|
||||
"""Initialize the var dependency tracking dicts.
|
||||
|
||||
Allows the state to know which vars each ComputedVar depends on and
|
||||
whether a ComputedVar depends on a var in its parent state.
|
||||
|
||||
Additional updates tracking dicts for vars and substates that always
|
||||
need to be recomputed.
|
||||
"""
|
||||
# Initialize per-class var dependency tracking.
|
||||
cls._computed_var_dependencies = defaultdict(set)
|
||||
cls._substate_var_dependencies = defaultdict(set)
|
||||
|
||||
inherited_vars = set(cls.inherited_vars).union(
|
||||
set(cls.inherited_backend_vars),
|
||||
)
|
||||
for cvar_name, cvar in cls.computed_vars.items():
|
||||
# Add the dependencies.
|
||||
for var in cvar._deps(objclass=cls):
|
||||
cls._computed_var_dependencies[var].add(cvar_name)
|
||||
if var in inherited_vars:
|
||||
# track that this substate depends on its parent for this var
|
||||
state_name = cls.get_name()
|
||||
parent_state = cls.get_parent_state()
|
||||
while parent_state is not None and var in parent_state.vars:
|
||||
parent_state._substate_var_dependencies[var].add(state_name)
|
||||
state_name, parent_state = (
|
||||
parent_state.get_name(),
|
||||
parent_state.get_parent_state(),
|
||||
)
|
||||
|
||||
# ComputedVar with cache=False always need to be recomputed
|
||||
cls._always_dirty_computed_vars = set(
|
||||
cvar_name
|
||||
for cvar_name, cvar in cls.computed_vars.items()
|
||||
if not cvar._cache
|
||||
)
|
||||
|
||||
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
|
||||
cls._always_dirty_substates = set()
|
||||
if cls._always_dirty_computed_vars:
|
||||
# Tell parent classes that this substate has always dirty computed vars
|
||||
state_name = cls.get_name()
|
||||
parent_state = cls.get_parent_state()
|
||||
while parent_state is not None:
|
||||
parent_state._always_dirty_substates.add(state_name)
|
||||
state_name, parent_state = (
|
||||
parent_state.get_name(),
|
||||
parent_state.get_parent_state(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_overridden_methods(cls):
|
||||
"""Check for shadow methods and raise error if any.
|
||||
@ -377,16 +423,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Returns:
|
||||
The vars to skip when serializing.
|
||||
"""
|
||||
return set(cls.inherited_vars) | {
|
||||
"parent_state",
|
||||
"substates",
|
||||
"dirty_vars",
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
"computed_var_dependencies",
|
||||
"substate_var_dependencies",
|
||||
"_backend_vars",
|
||||
}
|
||||
return (
|
||||
set(cls.inherited_vars)
|
||||
| {
|
||||
"parent_state",
|
||||
"substates",
|
||||
"dirty_vars",
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
}
|
||||
| RESERVED_BACKEND_VAR_NAMES
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
@ -540,6 +587,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for substate_class in cls.__subclasses__():
|
||||
substate_class.vars.setdefault(name, var)
|
||||
|
||||
# Reinitialize dependency tracking dicts.
|
||||
cls._init_var_dependency_dicts()
|
||||
|
||||
@classmethod
|
||||
def _set_var(cls, prop: BaseVar):
|
||||
"""Set the var as a class member.
|
||||
@ -749,6 +799,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore
|
||||
setattr(cls, param, func)
|
||||
|
||||
# Reinitialize dependency tracking dicts.
|
||||
cls._init_var_dependency_dicts()
|
||||
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
"""Get the state var.
|
||||
|
||||
@ -804,7 +857,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
setattr(self.parent_state, name, value)
|
||||
return
|
||||
|
||||
if types.is_backend_variable(name) and name != "_backend_vars":
|
||||
if types.is_backend_variable(name) and name not in RESERVED_BACKEND_VAR_NAMES:
|
||||
self._backend_vars.__setitem__(name, value)
|
||||
self.dirty_vars.add(name)
|
||||
self._mark_dirty()
|
||||
@ -814,7 +867,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
super().__setattr__(name, value)
|
||||
|
||||
# Add the var to the dirty list.
|
||||
if name in self.vars or name in self.computed_var_dependencies:
|
||||
if name in self.vars or name in self._computed_var_dependencies:
|
||||
self.dirty_vars.add(name)
|
||||
self._mark_dirty()
|
||||
|
||||
@ -1056,18 +1109,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
final=True,
|
||||
)
|
||||
|
||||
def _always_dirty_computed_vars(self) -> set[str]:
|
||||
"""The set of ComputedVars that always need to be recalculated.
|
||||
|
||||
Returns:
|
||||
Set of all ComputedVar in this state where cache=False
|
||||
"""
|
||||
return set(
|
||||
cvar_name
|
||||
for cvar_name, cvar in self.computed_vars.items()
|
||||
if not cvar._cache
|
||||
)
|
||||
|
||||
def _mark_dirty_computed_vars(self) -> None:
|
||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||
dirty_vars = self.dirty_vars
|
||||
@ -1092,7 +1133,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
return set(
|
||||
cvar
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self.computed_var_dependencies[dirty_var]
|
||||
for cvar in self._computed_var_dependencies[dirty_var]
|
||||
)
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
@ -1104,7 +1145,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
delta = {}
|
||||
|
||||
# Apply dirty variables down into substates
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||
self._mark_dirty()
|
||||
|
||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||
@ -1112,7 +1153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
delta_vars = (
|
||||
self.dirty_vars.intersection(self.base_vars)
|
||||
.union(self._dirty_computed_vars())
|
||||
.union(self._always_dirty_computed_vars())
|
||||
.union(self._always_dirty_computed_vars)
|
||||
)
|
||||
|
||||
subdelta = {
|
||||
@ -1125,7 +1166,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
# Recursively find the substate deltas.
|
||||
substates = self.substates
|
||||
for substate in self.dirty_substates:
|
||||
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
||||
delta.update(substates[substate].get_delta())
|
||||
|
||||
# Format the delta.
|
||||
@ -1151,7 +1192,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Propagate dirty var / computed var status into substates
|
||||
substates = self.substates
|
||||
for var in self.dirty_vars:
|
||||
for substate_name in self.substate_var_dependencies[var]:
|
||||
for substate_name in self._substate_var_dependencies[var]:
|
||||
self.dirty_substates.add(substate_name)
|
||||
substate = substates[substate_name]
|
||||
substate.dirty_vars.add(var)
|
||||
@ -1195,7 +1236,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if include_computed:
|
||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
||||
# trigger recalculation of dependent vars
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars)
|
||||
self._mark_dirty()
|
||||
|
||||
base_vars = {
|
||||
|
@ -257,7 +257,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
||||
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
||||
constants.ROUTER
|
||||
}
|
||||
assert constants.ROUTER in app.state().computed_var_dependencies
|
||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
||||
|
||||
|
||||
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
|
||||
@ -917,7 +917,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
|
||||
constants.ROUTER
|
||||
}
|
||||
assert constants.ROUTER in app.state().computed_var_dependencies
|
||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
||||
|
||||
sid = "mock_sid"
|
||||
client_ip = "127.0.0.1"
|
||||
|
@ -1215,7 +1215,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
assert isinstance(HandlerState.handler, EventHandler)
|
||||
|
||||
s = HandlerState()
|
||||
assert "cached_x_side_effect" in s.computed_var_dependencies["x"]
|
||||
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
|
||||
assert s.cached_x_side_effect == 1
|
||||
assert s.x == 43
|
||||
s.handler()
|
||||
@ -1283,11 +1283,11 @@ def test_computed_var_dependencies():
|
||||
return [z in self._z for z in range(5)]
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs.computed_var_dependencies["v"] == {"comp_v"}
|
||||
assert cs.computed_var_dependencies["w"] == {"comp_w"}
|
||||
assert cs.computed_var_dependencies["x"] == {"comp_x"}
|
||||
assert cs.computed_var_dependencies["y"] == {"comp_y"}
|
||||
assert cs.computed_var_dependencies["_z"] == {"comp_z"}
|
||||
assert cs._computed_var_dependencies["v"] == {"comp_v"}
|
||||
assert cs._computed_var_dependencies["w"] == {"comp_w"}
|
||||
assert cs._computed_var_dependencies["x"] == {"comp_x"}
|
||||
assert cs._computed_var_dependencies["y"] == {"comp_y"}
|
||||
assert cs._computed_var_dependencies["_z"] == {"comp_z"}
|
||||
|
||||
|
||||
def test_backend_method():
|
||||
|
Loading…
Reference in New Issue
Block a user