Speed up computed var dependency tracking (#864)
This commit is contained in:
parent
2d7c2bcc5e
commit
f019e0e55a
@ -1,3 +1,3 @@
|
|||||||
{
|
{
|
||||||
"version": "0.1.21"
|
"version": "0.1.25"
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
|
||||||
import traceback
|
import traceback
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@ -15,7 +15,6 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Backend vars inherited
|
# Backend vars inherited
|
||||||
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
|
||||||
|
|
||||||
# Mapping of var name to set of computed variables that depend on it
|
|
||||||
computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
|
|
||||||
|
|
||||||
# The event handlers.
|
# The event handlers.
|
||||||
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
|
||||||
|
|
||||||
@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# The routing path that triggered the state
|
# The routing path that triggered the state
|
||||||
router_data: Dict[str, Any] = {}
|
router_data: Dict[str, Any] = {}
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
# Mapping of var name to set of computed variables that depend on it
|
||||||
|
computed_var_dependencies: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
|
# Whether to track accessed vars.
|
||||||
|
track_vars: bool = False
|
||||||
|
|
||||||
|
# The current set of accessed vars during tracking.
|
||||||
|
tracked_vars: Set[str] = set()
|
||||||
|
|
||||||
|
def __init__(self, *args, parent_state: Optional[State] = None, **kwargs):
|
||||||
"""Initialize the state.
|
"""Initialize the state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: The args to pass to the Pydantic init method.
|
*args: The args to pass to the Pydantic init method.
|
||||||
|
parent_state: The parent state.
|
||||||
**kwargs: The kwargs to pass to the Pydantic init method.
|
**kwargs: The kwargs to pass to the Pydantic init method.
|
||||||
"""
|
"""
|
||||||
|
kwargs["parent_state"] = parent_state
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# Setup the substates.
|
# Setup the substates.
|
||||||
for substate in self.get_substates():
|
for substate in self.get_substates():
|
||||||
self.substates[substate.get_name()] = substate().set(parent_state=self)
|
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||||
|
|
||||||
# Convert the event handlers to functions.
|
# Convert the event handlers to functions.
|
||||||
for name, event_handler in self.event_handlers.items():
|
for name, event_handler in self.event_handlers.items():
|
||||||
@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Initialize the mutable fields.
|
# Initialize the mutable fields.
|
||||||
self._init_mutable_fields()
|
self._init_mutable_fields()
|
||||||
|
|
||||||
|
# Initialize computed vars dependencies.
|
||||||
|
self.computed_var_dependencies = defaultdict(set)
|
||||||
|
for cvar in self.computed_vars:
|
||||||
|
self.tracked_vars = set()
|
||||||
|
|
||||||
|
# Enable tracking and get the computed var.
|
||||||
|
self.track_vars = True
|
||||||
|
self.__getattribute__(cvar)
|
||||||
|
self.track_vars = False
|
||||||
|
|
||||||
|
# Add the dependencies.
|
||||||
|
for var in self.tracked_vars:
|
||||||
|
self.computed_var_dependencies[var].add(cvar)
|
||||||
|
|
||||||
def _init_mutable_fields(self):
|
def _init_mutable_fields(self):
|
||||||
"""Initialize mutable fields.
|
"""Initialize mutable fields.
|
||||||
|
|
||||||
@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
|
cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars}
|
||||||
|
|
||||||
# Set the base and computed vars.
|
# Set the base and computed vars.
|
||||||
skip_vars = set(cls.inherited_vars) | {
|
|
||||||
"parent_state",
|
|
||||||
"substates",
|
|
||||||
"dirty_vars",
|
|
||||||
"dirty_substates",
|
|
||||||
"router_data",
|
|
||||||
}
|
|
||||||
cls.base_vars = {
|
cls.base_vars = {
|
||||||
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
|
f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls)
|
||||||
for f in cls.get_fields().values()
|
for f in cls.get_fields().values()
|
||||||
if f.name not in skip_vars
|
if f.name not in cls.get_skip_vars()
|
||||||
}
|
}
|
||||||
cls.computed_vars = {
|
cls.computed_vars = {
|
||||||
v.name: v.set_state(cls)
|
v.name: v.set_state(cls)
|
||||||
@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
cls.event_handlers[name] = handler
|
cls.event_handlers[name] = handler
|
||||||
setattr(cls, name, handler)
|
setattr(cls, name, handler)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_skip_vars(cls) -> Set[str]:
|
||||||
|
"""Get the vars to skip when serializing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The vars to skip when serializing.
|
||||||
|
"""
|
||||||
|
return set(cls.inherited_vars) | {
|
||||||
|
"parent_state",
|
||||||
|
"substates",
|
||||||
|
"dirty_vars",
|
||||||
|
"dirty_substates",
|
||||||
|
"router_data",
|
||||||
|
"computed_var_dependencies",
|
||||||
|
"track_vars",
|
||||||
|
"tracked_vars",
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@functools.lru_cache()
|
@functools.lru_cache()
|
||||||
def get_parent_state(cls) -> Optional[Type[State]]:
|
def get_parent_state(cls) -> Optional[Type[State]]:
|
||||||
@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Returns:
|
Returns:
|
||||||
The value of the var.
|
The value of the var.
|
||||||
"""
|
"""
|
||||||
vars = {
|
# If the state hasn't been initialized yet, return the default value.
|
||||||
**super().__getattribute__("vars"),
|
if not super().__getattribute__("__dict__"):
|
||||||
**super().__getattribute__("backend_vars"),
|
return super().__getattribute__(name)
|
||||||
}
|
|
||||||
if name in vars:
|
# Check if tracking is enabled.
|
||||||
parent_frame, parent_frame_locals = _get_previous_recursive_frame_info()
|
if super().__getattribute__("track_vars"):
|
||||||
if parent_frame is not None:
|
# Get the non-computed vars.
|
||||||
computed_vars = super().__getattribute__("computed_vars")
|
all_vars = {
|
||||||
requesting_attribute_name = parent_frame_locals.get("name")
|
**super().__getattribute__("vars"),
|
||||||
if requesting_attribute_name in computed_vars:
|
**super().__getattribute__("backend_vars"),
|
||||||
# Keep track of any ComputedVar that depends on this Var
|
}
|
||||||
super().__getattribute__("computed_var_dependencies").setdefault(
|
# Add the var to the tracked vars.
|
||||||
name, set()
|
if name in all_vars:
|
||||||
).add(requesting_attribute_name)
|
super().__getattribute__("tracked_vars").add(name)
|
||||||
|
|
||||||
inherited_vars = {
|
inherited_vars = {
|
||||||
**super().__getattribute__("inherited_vars"),
|
**super().__getattribute__("inherited_vars"),
|
||||||
**super().__getattribute__("inherited_backend_vars"),
|
**super().__getattribute__("inherited_backend_vars"),
|
||||||
@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Returns:
|
Returns:
|
||||||
Set of computed vars to include in the delta.
|
Set of computed vars to include in the delta.
|
||||||
"""
|
"""
|
||||||
dirty_computed_vars = set(
|
return set(
|
||||||
cvar
|
cvar
|
||||||
for dirty_var in from_vars or self.dirty_vars
|
for dirty_var in from_vars or self.dirty_vars
|
||||||
for cvar in self.computed_vars
|
for cvar in self.computed_vars
|
||||||
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
if cvar in self.computed_var_dependencies.get(dirty_var, set())
|
||||||
)
|
)
|
||||||
if dirty_computed_vars:
|
|
||||||
# recursive call to catch computed vars that depend on computed vars
|
|
||||||
return dirty_computed_vars | self._dirty_computed_vars(
|
|
||||||
from_vars=dirty_computed_vars
|
|
||||||
)
|
|
||||||
return dirty_computed_vars
|
|
||||||
|
|
||||||
def get_delta(self) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
@ -844,24 +871,3 @@ def _convert_mutable_datatypes(
|
|||||||
field_value, reassign_field=reassign_field, field_name=field_name
|
field_value, reassign_field=reassign_field, field_name=field_name
|
||||||
)
|
)
|
||||||
return field_value
|
return field_value
|
||||||
|
|
||||||
|
|
||||||
def _get_previous_recursive_frame_info() -> (
|
|
||||||
Tuple[Optional[inspect.FrameInfo], Dict[str, Any]]
|
|
||||||
):
|
|
||||||
"""Find the previous frame of the same function that calls this helper.
|
|
||||||
|
|
||||||
For example, if this function is called from `State.__getattribute__`
|
|
||||||
(parent frame), then the returned frame will be the next earliest call
|
|
||||||
of the same function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (frame_info, local_vars)
|
|
||||||
|
|
||||||
If no previous recursive frame is found up the stack, the frame info will be None.
|
|
||||||
"""
|
|
||||||
_this_frame, parent_frame, *prev_frames = inspect.stack()
|
|
||||||
for frame in prev_frames:
|
|
||||||
if frame.frame.f_code == parent_frame.frame.f_code:
|
|
||||||
return frame, frame.frame.f_locals
|
|
||||||
return None, {}
|
|
||||||
|
@ -155,13 +155,7 @@ def test_base_class_vars(test_state):
|
|||||||
cls = type(test_state)
|
cls = type(test_state)
|
||||||
|
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field in (
|
if field in test_state.get_skip_vars():
|
||||||
"parent_state",
|
|
||||||
"substates",
|
|
||||||
"dirty_vars",
|
|
||||||
"dirty_substates",
|
|
||||||
"router_data",
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
prop = getattr(cls, field)
|
prop = getattr(cls, field)
|
||||||
assert isinstance(prop, BaseVar)
|
assert isinstance(prop, BaseVar)
|
||||||
@ -819,3 +813,19 @@ def test_dirty_computed_var_from_backend_var(interdependent_state):
|
|||||||
assert interdependent_state.get_delta() == {
|
assert interdependent_state.get_delta() == {
|
||||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_child_state():
|
||||||
|
class MainState(State):
|
||||||
|
v: int = 2
|
||||||
|
|
||||||
|
class ChildState(MainState):
|
||||||
|
@ComputedVar
|
||||||
|
def rendered_var(self):
|
||||||
|
return self.v
|
||||||
|
|
||||||
|
ms = MainState()
|
||||||
|
cs = ms.substates[ChildState.get_name()]
|
||||||
|
assert ms.v == 2
|
||||||
|
assert cs.v == 2
|
||||||
|
assert cs.rendered_var == 2
|
||||||
|
Loading…
Reference in New Issue
Block a user