Speed up computed var dependency tracking (#864)

This commit is contained in:
Nikhil Rao 2023-04-25 13:56:24 -07:00 committed by GitHub
parent 2d7c2bcc5e
commit f019e0e55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 65 deletions

View File

@ -1,3 +1,3 @@
{
"version": "0.1.21"
"version": "0.1.25"
}

View File

@ -3,9 +3,9 @@ from __future__ import annotations
import asyncio
import functools
import inspect
import traceback
from abc import ABC
from collections import defaultdict
from typing import (
Any,
Callable,
@ -15,7 +15,6 @@ from typing import (
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Backend vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state
router_data: Dict[str, Any] = {}
def __init__(self, *args, **kwargs):
# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: Dict[str, Set[str]] = {}
# Whether to track accessed vars.
track_vars: bool = False
# The current set of accessed vars during tracking.
tracked_vars: Set[str] = set()
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
"""Initialize the state.
Args:
*args: The args to pass to the Pydantic init method.
parent_state: The parent state.
**kwargs: The kwargs to pass to the Pydantic init method.
"""
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)
# Setup the substates.
for substate in self.get_substates():
self.substates[substate.get_name()] = substate().set(parent_state=self)
self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
for name, event_handler in self.event_handlers.items():
@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Initialize the mutable fields.
self._init_mutable_fields()
# Initialize computed vars dependencies.
self.computed_var_dependencies = defaultdict(set)
for cvar in self.computed_vars:
self.tracked_vars = set()
# Enable tracking and get the computed var.
self.track_vars = True
self.__getattribute__(cvar)
self.track_vars = False
# Add the dependencies.
for var in self.tracked_vars:
self.computed_var_dependencies[var].add(cvar)
def _init_mutable_fields(self):
"""Initialize mutable fields.
@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
# Set the base and computed vars.
skip_vars = set(cls.inherited_vars) | {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
}
cls.base_vars = {
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
for f in cls.get_fields().values()
if f.name not in skip_vars
if f.name not in cls.get_skip_vars()
}
cls.computed_vars = {
v.name: v.set_state(cls)
@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.event_handlers[name] = handler
setattr(cls, name, handler)
@classmethod
def get_skip_vars(cls) -> Set[str]:
"""Get the vars to skip when serializing.
Returns:
The vars to skip when serializing.
"""
return set(cls.inherited_vars) | {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
"computed_var_dependencies",
"track_vars",
"tracked_vars",
}
@classmethod
@functools.lru_cache()
def get_parent_state(cls) -> Optional[Type[State]]:
@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The value of the var.
"""
vars = {
**super().__getattribute__("vars"),
**super().__getattribute__("backend_vars"),
}
if name in vars:
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
if parent_frame is not None:
computed_vars = super().__getattribute__("computed_vars")
requesting_attribute_name = parent_frame_locals.get("name")
if requesting_attribute_name in computed_vars:
# Keep track of any ComputedVar that depends on this Var
super().__getattribute__("computed_var_dependencies").setdefault(
name, set()
).add(requesting_attribute_name)
# If the state hasn't been initialized yet, return the default value.
if not super().__getattribute__("__dict__"):
return super().__getattribute__(name)
# Check if tracking is enabled.
if super().__getattribute__("track_vars"):
# Get the non-computed vars.
all_vars = {
**super().__getattribute__("vars"),
**super().__getattribute__("backend_vars"),
}
# Add the var to the tracked vars.
if name in all_vars:
super().__getattribute__("tracked_vars").add(name)
inherited_vars = {
**super().__getattribute__("inherited_vars"),
**super().__getattribute__("inherited_backend_vars"),
@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns:
Set of computed vars to include in the delta.
"""
dirty_computed_vars = set(
return set(
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_vars
if cvar in self.computed_var_dependencies.get(dirty_var, set())
)
if dirty_computed_vars:
# recursive call to catch computed vars that depend on computed vars
return dirty_computed_vars | self._dirty_computed_vars(
from_vars=dirty_computed_vars
)
return dirty_computed_vars
def get_delta(self) -> Delta:
"""Get the delta for the state.
@ -844,24 +871,3 @@ def _convert_mutable_datatypes(
field_value, reassign_field=reassign_field, field_name=field_name
)
return field_value
def _get_previous_recursive_frame_info() -> (
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
):
"""Find the previous frame of the same function that calls this helper.
For example, if this function is called from `State.__getattribute__`
(parent frame), then the returned frame will be the next earliest call
of the same function.
Returns:
Tuple of (frame_info, local_vars)
If no previous recursive frame is found up the stack, the frame info will be None.
"""
_this_frame, parent_frame, *prev_frames = inspect.stack()
for frame in prev_frames:
if frame.frame.f_code == parent_frame.frame.f_code:
return frame, frame.frame.f_locals
return None, {}

View File

@ -155,13 +155,7 @@ def test_base_class_vars(test_state):
cls = type(test_state)
for field in fields:
if field in (
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
):
if field in test_state.get_skip_vars():
continue
prop = getattr(cls, field)
assert isinstance(prop, BaseVar)
@ -819,3 +813,19 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4},
}
def test_child_state():
class MainState(State):
v: int = 2
class ChildState(MainState):
@ComputedVar
def rendered_var(self):
return self.v
ms = MainState()
cs = ms.substates[ChildState.get_name()]
assert ms.v == 2
assert cs.v == 2
assert cs.rendered_var == 2