Speed up computed var dependency tracking (#864)
This commit is contained in:
parent
2d7c2bcc5e
commit
f019e0e55a
@ -1,3 +1,3 @@
|
||||
{
|
||||
"version": "0.1.21"
|
||||
"version": "0.1.25"
|
||||
}
|
||||
|
@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -15,7 +15,6 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Backend vars inherited
|
||||
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
||||
|
||||
# The event handlers.
|
||||
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
||||
|
||||
@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The routing path that triggered the state
|
||||
router_data: Dict[str, Any] = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Mapping of var name to set of computed variables that depend on it
|
||||
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||
|
||||
# Whether to track accessed vars.
|
||||
track_vars: bool = False
|
||||
|
||||
# The current set of accessed vars during tracking.
|
||||
tracked_vars: Set[str] = set()
|
||||
|
||||
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
||||
"""Initialize the state.
|
||||
|
||||
Args:
|
||||
*args: The args to pass to the Pydantic init method.
|
||||
parent_state: The parent state.
|
||||
**kwargs: The kwargs to pass to the Pydantic init method.
|
||||
"""
|
||||
kwargs["parent_state"] = parent_state
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Setup the substates.
|
||||
for substate in self.get_substates():
|
||||
self.substates[substate.get_name()] = substate().set(parent_state=self)
|
||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||
|
||||
# Convert the event handlers to functions.
|
||||
for name, event_handler in self.event_handlers.items():
|
||||
@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Initialize the mutable fields.
|
||||
self._init_mutable_fields()
|
||||
|
||||
# Initialize computed vars dependencies.
|
||||
self.computed_var_dependencies = defaultdict(set)
|
||||
for cvar in self.computed_vars:
|
||||
self.tracked_vars = set()
|
||||
|
||||
# Enable tracking and get the computed var.
|
||||
self.track_vars = True
|
||||
self.__getattribute__(cvar)
|
||||
self.track_vars = False
|
||||
|
||||
# Add the dependencies.
|
||||
for var in self.tracked_vars:
|
||||
self.computed_var_dependencies[var].add(cvar)
|
||||
|
||||
def _init_mutable_fields(self):
|
||||
"""Initialize mutable fields.
|
||||
|
||||
@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
|
||||
|
||||
# Set the base and computed vars.
|
||||
skip_vars = set(cls.inherited_vars) | {
|
||||
"parent_state",
|
||||
"substates",
|
||||
"dirty_vars",
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
}
|
||||
cls.base_vars = {
|
||||
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
|
||||
for f in cls.get_fields().values()
|
||||
if f.name not in skip_vars
|
||||
if f.name not in cls.get_skip_vars()
|
||||
}
|
||||
cls.computed_vars = {
|
||||
v.name: v.set_state(cls)
|
||||
@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.event_handlers[name] = handler
|
||||
setattr(cls, name, handler)
|
||||
|
||||
@classmethod
|
||||
def get_skip_vars(cls) -> Set[str]:
|
||||
"""Get the vars to skip when serializing.
|
||||
|
||||
Returns:
|
||||
The vars to skip when serializing.
|
||||
"""
|
||||
return set(cls.inherited_vars) | {
|
||||
"parent_state",
|
||||
"substates",
|
||||
"dirty_vars",
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
"computed_var_dependencies",
|
||||
"track_vars",
|
||||
"tracked_vars",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def get_parent_state(cls) -> Optional[Type[State]]:
|
||||
@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Returns:
|
||||
The value of the var.
|
||||
"""
|
||||
vars = {
|
||||
**super().__getattribute__("vars"),
|
||||
**super().__getattribute__("backend_vars"),
|
||||
}
|
||||
if name in vars:
|
||||
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
|
||||
if parent_frame is not None:
|
||||
computed_vars = super().__getattribute__("computed_vars")
|
||||
requesting_attribute_name = parent_frame_locals.get("name")
|
||||
if requesting_attribute_name in computed_vars:
|
||||
# Keep track of any ComputedVar that depends on this Var
|
||||
super().__getattribute__("computed_var_dependencies").setdefault(
|
||||
name, set()
|
||||
).add(requesting_attribute_name)
|
||||
# If the state hasn't been initialized yet, return the default value.
|
||||
if not super().__getattribute__("__dict__"):
|
||||
return super().__getattribute__(name)
|
||||
|
||||
# Check if tracking is enabled.
|
||||
if super().__getattribute__("track_vars"):
|
||||
# Get the non-computed vars.
|
||||
all_vars = {
|
||||
**super().__getattribute__("vars"),
|
||||
**super().__getattribute__("backend_vars"),
|
||||
}
|
||||
# Add the var to the tracked vars.
|
||||
if name in all_vars:
|
||||
super().__getattribute__("tracked_vars").add(name)
|
||||
|
||||
inherited_vars = {
|
||||
**super().__getattribute__("inherited_vars"),
|
||||
**super().__getattribute__("inherited_backend_vars"),
|
||||
@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Returns:
|
||||
Set of computed vars to include in the delta.
|
||||
"""
|
||||
dirty_computed_vars = set(
|
||||
return set(
|
||||
cvar
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self.computed_vars
|
||||
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
||||
)
|
||||
if dirty_computed_vars:
|
||||
# recursive call to catch computed vars that depend on computed vars
|
||||
return dirty_computed_vars | self._dirty_computed_vars(
|
||||
from_vars=dirty_computed_vars
|
||||
)
|
||||
return dirty_computed_vars
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
@ -844,24 +871,3 @@ def _convert_mutable_datatypes(
|
||||
field_value, reassign_field=reassign_field, field_name=field_name
|
||||
)
|
||||
return field_value
|
||||
|
||||
|
||||
def _get_previous_recursive_frame_info() -> (
|
||||
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
|
||||
):
|
||||
"""Find the previous frame of the same function that calls this helper.
|
||||
|
||||
For example, if this function is called from `State.__getattribute__`
|
||||
(parent frame), then the returned frame will be the next earliest call
|
||||
of the same function.
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_info, local_vars)
|
||||
|
||||
If no previous recursive frame is found up the stack, the frame info will be None.
|
||||
"""
|
||||
_this_frame, parent_frame, *prev_frames = inspect.stack()
|
||||
for frame in prev_frames:
|
||||
if frame.frame.f_code == parent_frame.frame.f_code:
|
||||
return frame, frame.frame.f_locals
|
||||
return None, {}
|
||||
|
@ -155,13 +155,7 @@ def test_base_class_vars(test_state):
|
||||
cls = type(test_state)
|
||||
|
||||
for field in fields:
|
||||
if field in (
|
||||
"parent_state",
|
||||
"substates",
|
||||
"dirty_vars",
|
||||
"dirty_substates",
|
||||
"router_data",
|
||||
):
|
||||
if field in test_state.get_skip_vars():
|
||||
continue
|
||||
prop = getattr(cls, field)
|
||||
assert isinstance(prop, BaseVar)
|
||||
@ -819,3 +813,19 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||
}
|
||||
|
||||
|
||||
def test_child_state():
|
||||
class MainState(State):
|
||||
v: int = 2
|
||||
|
||||
class ChildState(MainState):
|
||||
@ComputedVar
|
||||
def rendered_var(self):
|
||||
return self.v
|
||||
|
||||
ms = MainState()
|
||||
cs = ms.substates[ChildState.get_name()]
|
||||
assert ms.v == 2
|
||||
assert cs.v == 2
|
||||
assert cs.rendered_var == 2
|
||||
|
Loading…
Reference in New Issue
Block a user