From b4755b81236d514d9716733cdd3d2891ece075a0 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sun, 23 Apr 2023 17:48:44 -0700 Subject: [PATCH] Only update ComputedVar when dependent vars change (#840) --- pynecone/state.py | 71 +++++++++++++++++++++++++++++++- tests/test_state.py | 99 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 165 insertions(+), 5 deletions(-) diff --git a/pynecone/state.py b/pynecone/state.py index 03df6e194..5a20ccb3c 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import functools +import inspect import traceback from abc import ABC from typing import ( @@ -14,6 +15,7 @@ from typing import ( Optional, Sequence, Set, + Tuple, Type, Union, ) @@ -51,6 +53,9 @@ class State(Base, ABC): # Backend vars inherited inherited_backend_vars: ClassVar[Dict[str, Any]] = {} + # Mapping of var name to set of computed variables that depend on it + computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} + # The event handlers. event_handlers: ClassVar[Dict[str, EventHandler]] = {} @@ -171,6 +176,7 @@ class State(Base, ABC): **cls.base_vars, **cls.computed_vars, } + cls.computed_var_dependencies = {} # Setup the base vars at the class level. for prop in cls.base_vars.values(): @@ -472,12 +478,28 @@ class State(Base, ABC): If the var is inherited, get the var from the parent state. + If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies. + Args: name: The name of the var. Returns: The value of the var. """ + vars = { + **super().__getattribute__("vars"), + **super().__getattribute__("backend_vars"), + } + if name in vars: + parent_frame, parent_frame_locals = _get_previous_recursive_frame_info() + if parent_frame is not None: + computed_vars = super().__getattribute__("computed_vars") + requesting_attribute_name = parent_frame_locals.get("name") + if requesting_attribute_name in computed_vars: + # Keep track of any ComputedVar that depends on this Var + super().__getattribute__("computed_var_dependencies").setdefault( + name, set() + ).add(requesting_attribute_name) inherited_vars = { **super().__getattribute__("inherited_vars"), **super().__getattribute__("inherited_backend_vars"), @@ -505,6 +527,7 @@ class State(Base, ABC): if types.is_backend_variable(name): self.backend_vars.__setitem__(name, value) + self.dirty_vars.add(name) self.mark_dirty() return @@ -622,6 +645,28 @@ class State(Base, ABC): # Return the state update. return StateUpdate(delta=delta, events=events) + def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]: + """Get ComputedVars that need to be recomputed based on dirty_vars. + + Args: + from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars. + + Returns: + Set of computed vars to include in the delta. + """ + dirty_computed_vars = set( + cvar + for dirty_var in from_vars or self.dirty_vars + for cvar in self.computed_vars + if cvar in self.computed_var_dependencies.get(dirty_var, set()) + ) + if dirty_computed_vars: + # recursive call to catch computed vars that depend on computed vars + return dirty_computed_vars | self._dirty_computed_vars( + from_vars=dirty_computed_vars + ) + return dirty_computed_vars + def get_delta(self) -> Delta: """Get the delta for the state. @@ -630,10 +675,11 @@ class State(Base, ABC): """ delta = {} - # Return the dirty vars, as well as all computed vars. + # Return the dirty vars, as well as computed vars depending on dirty vars. subdelta = { prop: getattr(self, prop) - for prop in self.dirty_vars | self.computed_vars.keys() + for prop in self.dirty_vars | self._dirty_computed_vars() + if not types.is_backend_variable(prop) } if len(subdelta) > 0: delta[self.get_full_name()] = subdelta @@ -803,3 +849,24 @@ def _convert_mutable_datatypes( field_value, reassign_field=reassign_field, field_name=field_name ) return field_value + + +def _get_previous_recursive_frame_info() -> ( + Tuple[Optional[inspect.FrameInfo], Dict[str, Any]] +): + """Find the previous frame of the same function that calls this helper. + + For example, if this function is called from `State.__getattribute__` + (parent frame), then the returned frame will be the next earliest call + of the same function. + + Returns: + Tuple of (frame_info, local_vars) + + If no previous recursive frame is found up the stack, the frame info will be None. + """ + _this_frame, parent_frame, *prev_frames = inspect.stack() + for frame in prev_frames: + if frame.frame.f_code == parent_frame.frame.f_code: + return frame, frame.frame.f_locals + return None, {} diff --git a/tests/test_state.py b/tests/test_state.py index 14f4c71cb..fb146e17e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -582,7 +582,7 @@ async def test_process_event_simple(test_state): assert test_state.num1 == 69 # The delta should contain the changes, including computed vars. - assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}} + assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}} assert update.events == [] @@ -606,7 +606,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert child_state.count == 24 assert update.delta == { "test_state.child_state": {"value": "HI", "count": 24}, - "test_state": {"sum": 3.14, "upper": ""}, } test_state.clean() @@ -621,7 +620,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert grandchild_state.value2 == "new" assert update.delta == { "test_state.child_state.grandchild_state": {"value2": "new"}, - "test_state": {"sum": 3.14, "upper": ""}, } @@ -724,3 +722,98 @@ def test_add_var(test_state): test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10}) assert test_state.dynamic_dict == {"k1": 5, "k2": 10} assert test_state.dynamic_dict == {"k1": 5, "k2": 10} + + +class InterdependentState(State): + """A state with 3 vars and 3 computed vars. + + x: a variable that no computed var depends on + v1: a varable that one computed var directly depeneds on + _v2: a backend variable that one computed var directly depends on + + v1x2: a computed var that depends on v1 + v2x2: a computed var that depends on backend var _v2 + v1x2x2: a computed var that depends on computed var v1x2 + """ + + x: int = 0 + v1: int = 0 + _v2: int = 1 + + @ComputedVar + def v1x2(self) -> int: + """depends on var v1. + + Returns: + Var v1 multiplied by 2 + """ + return self.v1 * 2 + + @ComputedVar + def v2x2(self) -> int: + """depends on backend var _v2. + + Returns: + backend var _v2 multiplied by 2 + """ + return self._v2 * 2 + + @ComputedVar + def v1x2x2(self) -> int: + """depends on ComputedVar v1x2. + + Returns: + ComputedVar v1x2 multiplied by 2 + """ + return self.v1x2 * 2 + + +@pytest.fixture +def interdependent_state() -> State: + """A state with varying dependency between vars. + + Returns: + instance of InterdependentState + """ + s = InterdependentState() + s.dict() # prime initial relationships by accessing all ComputedVars + return s + + +def test_not_dirty_computed_var_from_var(interdependent_state): + """Set Var that no ComputedVar depends on, expect no recalculation. + + Args: + interdependent_state: A state with varying Var dependencies. + """ + interdependent_state.x = 5 + assert interdependent_state.get_delta() == { + interdependent_state.get_full_name(): {"x": 5}, + } + + +def test_dirty_computed_var_from_var(interdependent_state): + """Set Var that ComputedVar depends on, expect recalculation. + + The other ComputedVar depends on the changed ComputedVar and should also be + recalculated. No other ComputedVars should be recalculated. + + Args: + interdependent_state: A state with varying Var dependencies. + """ + interdependent_state.v1 = 1 + assert interdependent_state.get_delta() == { + interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, + } + + +def test_dirty_computed_var_from_backend_var(interdependent_state): + """Set backend var that ComputedVar depends on, expect recalculation. + + Args: + interdependent_state: A state with varying Var dependencies. + """ + interdependent_state._v2 = 2 + assert interdependent_state.get_delta() == { + interdependent_state.get_full_name(): {"v2x2": 4}, + }