Speed up computed var dependency tracking (#864)

This commit is contained in:
Nikhil Rao 2023-04-25 13:56:24 -07:00 committed by GitHub
parent 2d7c2bcc5e
commit f019e0e55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 65 deletions

View File

@ -1,3 +1,3 @@
{ {
"version": "0.1.21" "version": "0.1.25"
} }

View File

@ -3,9 +3,9 @@ from __future__ import annotations
import asyncio import asyncio
import functools import functools
import inspect
import traceback import traceback
from abc import ABC from abc import ABC
from collections import defaultdict
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -15,7 +15,6 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple,
Type, Type,
Union, Union,
) )
@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Backend vars inherited # Backend vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {} inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
# Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
# The event handlers. # The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {} event_handlers: ClassVar[Dict[str, EventHandler]] = {}
@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state # The routing path that triggered the state
router_data: Dict[str, Any] = {} router_data: Dict[str, Any] = {}
def __init__(self, *args, **kwargs): # Mapping of var name to set of computed variables that depend on it
computed_var_dependencies: Dict[str, Set[str]] = {}
# Whether to track accessed vars.
track_vars: bool = False
# The current set of accessed vars during tracking.
tracked_vars: Set[str] = set()
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
"""Initialize the state. """Initialize the state.
Args: Args:
*args: The args to pass to the Pydantic init method. *args: The args to pass to the Pydantic init method.
parent_state: The parent state.
**kwargs: The kwargs to pass to the Pydantic init method. **kwargs: The kwargs to pass to the Pydantic init method.
""" """
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Setup the substates. # Setup the substates.
for substate in self.get_substates(): for substate in self.get_substates():
self.substates[substate.get_name()] = substate().set(parent_state=self) self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions. # Convert the event handlers to functions.
for name, event_handler in self.event_handlers.items(): for name, event_handler in self.event_handlers.items():
@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Initialize the mutable fields. # Initialize the mutable fields.
self._init_mutable_fields() self._init_mutable_fields()
# Initialize computed vars dependencies.
self.computed_var_dependencies = defaultdict(set)
for cvar in self.computed_vars:
self.tracked_vars = set()
# Enable tracking and get the computed var.
self.track_vars = True
self.__getattribute__(cvar)
self.track_vars = False
# Add the dependencies.
for var in self.tracked_vars:
self.computed_var_dependencies[var].add(cvar)
def _init_mutable_fields(self): def _init_mutable_fields(self):
"""Initialize mutable fields. """Initialize mutable fields.
@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars} cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
# Set the base and computed vars. # Set the base and computed vars.
skip_vars = set(cls.inherited_vars) | {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
}
cls.base_vars = { cls.base_vars = {
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls) f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
for f in cls.get_fields().values() for f in cls.get_fields().values()
if f.name not in skip_vars if f.name not in cls.get_skip_vars()
} }
cls.computed_vars = { cls.computed_vars = {
v.name: v.set_state(cls) v.name: v.set_state(cls)
@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.event_handlers[name] = handler cls.event_handlers[name] = handler
setattr(cls, name, handler) setattr(cls, name, handler)
@classmethod
def get_skip_vars(cls) -> Set[str]:
"""Get the vars to skip when serializing.
Returns:
The vars to skip when serializing.
"""
return set(cls.inherited_vars) | {
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
"computed_var_dependencies",
"track_vars",
"tracked_vars",
}
@classmethod @classmethod
@functools.lru_cache() @functools.lru_cache()
def get_parent_state(cls) -> Optional[Type[State]]: def get_parent_state(cls) -> Optional[Type[State]]:
@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
The value of the var. The value of the var.
""" """
vars = { # If the state hasn't been initialized yet, return the default value.
**super().__getattribute__("vars"), if not super().__getattribute__("__dict__"):
**super().__getattribute__("backend_vars"), return super().__getattribute__(name)
}
if name in vars: # Check if tracking is enabled.
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info() if super().__getattribute__("track_vars"):
if parent_frame is not None: # Get the non-computed vars.
computed_vars = super().__getattribute__("computed_vars") all_vars = {
requesting_attribute_name = parent_frame_locals.get("name") **super().__getattribute__("vars"),
if requesting_attribute_name in computed_vars: **super().__getattribute__("backend_vars"),
# Keep track of any ComputedVar that depends on this Var }
super().__getattribute__("computed_var_dependencies").setdefault( # Add the var to the tracked vars.
name, set() if name in all_vars:
).add(requesting_attribute_name) super().__getattribute__("tracked_vars").add(name)
inherited_vars = { inherited_vars = {
**super().__getattribute__("inherited_vars"), **super().__getattribute__("inherited_vars"),
**super().__getattribute__("inherited_backend_vars"), **super().__getattribute__("inherited_backend_vars"),
@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
Set of computed vars to include in the delta. Set of computed vars to include in the delta.
""" """
dirty_computed_vars = set( return set(
cvar cvar
for dirty_var in from_vars or self.dirty_vars for dirty_var in from_vars or self.dirty_vars
for cvar in self.computed_vars for cvar in self.computed_vars
if cvar in self.computed_var_dependencies.get(dirty_var, set()) if cvar in self.computed_var_dependencies.get(dirty_var, set())
) )
if dirty_computed_vars:
# recursive call to catch computed vars that depend on computed vars
return dirty_computed_vars | self._dirty_computed_vars(
from_vars=dirty_computed_vars
)
return dirty_computed_vars
def get_delta(self) -> Delta: def get_delta(self) -> Delta:
"""Get the delta for the state. """Get the delta for the state.
@ -844,24 +871,3 @@ def _convert_mutable_datatypes(
field_value, reassign_field=reassign_field, field_name=field_name field_value, reassign_field=reassign_field, field_name=field_name
) )
return field_value return field_value
def _get_previous_recursive_frame_info() -> (
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
):
"""Find the previous frame of the same function that calls this helper.
For example, if this function is called from `State.__getattribute__`
(parent frame), then the returned frame will be the next earliest call
of the same function.
Returns:
Tuple of (frame_info, local_vars)
If no previous recursive frame is found up the stack, the frame info will be None.
"""
_this_frame, parent_frame, *prev_frames = inspect.stack()
for frame in prev_frames:
if frame.frame.f_code == parent_frame.frame.f_code:
return frame, frame.frame.f_locals
return None, {}

View File

@ -155,13 +155,7 @@ def test_base_class_vars(test_state):
cls = type(test_state) cls = type(test_state)
for field in fields: for field in fields:
if field in ( if field in test_state.get_skip_vars():
"parent_state",
"substates",
"dirty_vars",
"dirty_substates",
"router_data",
):
continue continue
prop = getattr(cls, field) prop = getattr(cls, field)
assert isinstance(prop, BaseVar) assert isinstance(prop, BaseVar)
@ -819,3 +813,19 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
assert interdependent_state.get_delta() == { assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4}, interdependent_state.get_full_name(): {"v2x2": 4},
} }
def test_child_state():
class MainState(State):
v: int = 2
class ChildState(MainState):
@ComputedVar
def rendered_var(self):
return self.v
ms = MainState()
cs = ms.substates[ChildState.get_name()]
assert ms.v == 2
assert cs.v == 2
assert cs.rendered_var == 2