From 0491852a4505be1f9cbb409a81fa4b59bc53a98f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 11 May 2023 17:47:54 -0700 Subject: [PATCH] @pc.cached_var: explicit opt-in for ComputedVar tracking (#1000) --- pynecone/__init__.py | 1 + pynecone/state.py | 68 ++++++++++++------- pynecone/vars.py | 22 ++++++- tests/test_state.py | 154 ++++++++++++++++++++++++++++++++----------- 4 files changed, 182 insertions(+), 63 deletions(-) diff --git a/pynecone/__init__.py b/pynecone/__init__.py index 800179272..0e6d655b3 100644 --- a/pynecone/__init__.py +++ b/pynecone/__init__.py @@ -31,3 +31,4 @@ from .state import ComputedVar as var from .state import State as State from .style import toggle_color_mode as toggle_color_mode from .vars import Var as Var +from .vars import cached_var as cached_var diff --git a/pynecone/state.py b/pynecone/state.py index 3f3ba1eab..429831e56 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -688,22 +688,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Return the state update. return StateUpdate(delta=delta, events=events) + def _always_dirty_computed_vars(self) -> Set[str]: + """The set of ComputedVars that always need to be recalculated. + + Returns: + Set of all ComputedVar in this state where cache=False + """ + return set( + cvar_name + for cvar_name, cvar in self.computed_vars.items() + if not cvar.cache + ) + def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" - # Mark all ComputedVars as dirty. - for cvar in self.computed_vars.values(): - cvar.mark_dirty(instance=self) - - # TODO: Uncomment the actual implementation below. - # dirty_vars = self.dirty_vars - # while dirty_vars: - # calc_vars, dirty_vars = dirty_vars, set() - # for cvar in self._dirty_computed_vars(from_vars=calc_vars): - # self.dirty_vars.add(cvar) - # dirty_vars.add(cvar) - # actual_var = self.computed_vars.get(cvar) - # if actual_var: - # actual_var.mark_dirty(instance=self) + dirty_vars = self.dirty_vars + while dirty_vars: + calc_vars, dirty_vars = dirty_vars, set() + for cvar in self._dirty_computed_vars(from_vars=calc_vars): + self.dirty_vars.add(cvar) + dirty_vars.add(cvar) + actual_var = self.computed_vars.get(cvar) + if actual_var: + actual_var.mark_dirty(instance=self) def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]: """Determine ComputedVars that need to be recalculated based on the given vars. @@ -714,13 +721,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: Set of computed vars to include in the delta. """ - return set(self.computed_vars) - # TODO: Uncomment the actual implementation below. - # return set( - # cvar - # for dirty_var in from_vars or self.dirty_vars - # for cvar in self.computed_var_dependencies[dirty_var] - # ) + return set( + cvar + for dirty_var in from_vars or self.dirty_vars + for cvar in self.computed_var_dependencies[dirty_var] + ) def get_delta(self) -> Delta: """Get the delta for the state. @@ -730,11 +735,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow): """ delta = {} - # Return the dirty vars and dependent computed vars - self._mark_dirty_computed_vars() - delta_vars = self.dirty_vars.intersection(self.base_vars).union( - self._dirty_computed_vars() + # Apply dirty variables down into substates + self.dirty_vars.update(self._always_dirty_computed_vars()) + self.mark_dirty() + + # Return the dirty vars for this instance, any cached/dependent computed vars, + # and always dirty computed vars (cache=False) + delta_vars = ( + self.dirty_vars.intersection(self.base_vars) + .union(self._dirty_computed_vars()) + .union(self._always_dirty_computed_vars()) ) + subdelta = { prop: getattr(self, prop) for prop in delta_vars @@ -797,6 +809,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The object as a dictionary. """ + if include_computed: + # Apply dirty variables down into substates to allow never-cached ComputedVar to + # trigger recalculation of dependent vars + self.dirty_vars.update(self._always_dirty_computed_vars()) + self.mark_dirty() + base_vars = { prop_name: self.get_value(getattr(self, prop_name)) for prop_name in self.base_vars diff --git a/pynecone/vars.py b/pynecone/vars.py index 425f0b54a..87062e4e6 100644 --- a/pynecone/vars.py +++ b/pynecone/vars.py @@ -801,6 +801,9 @@ class BaseVar(Var, Base): class ComputedVar(Var, property): """A field with computed getters.""" + # Whether to track dependencies and cache computed values + cache: bool = False + @property def name(self) -> str: """Get the name of the var. @@ -832,7 +835,7 @@ class ComputedVar(Var, property): Returns: The value of the var for the given instance. """ - if instance is None: + if instance is None or not self.cache: return super().__get__(instance, owner) # handle caching @@ -906,6 +909,23 @@ class ComputedVar(Var, property): return Any +def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: + """A field with computed getter that tracks other state dependencies. + + The cached_var will only be recalculated when other state vars that it + depends on are modified. + + Args: + fget: the function that calculates the variable value. + + Returns: + ComputedVar that is recomputed when dependencies change. + """ + cvar = ComputedVar(fget=fget) + cvar.cache = True + return cvar + + class PCList(list): """A custom list that pynecone can detect its mutation.""" diff --git a/tests/test_state.py b/tests/test_state.py index c49464511..44d9bddc1 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -3,6 +3,7 @@ from typing import Dict, List import pytest from plotly.graph_objects import Figure +import pynecone as pc from pynecone.base import Base from pynecone.constants import IS_HYDRATED, RouteVar from pynecone.event import Event, EventHandler @@ -484,13 +485,11 @@ def test_set_dirty_var(test_state): # Setting a var should mark it as dirty. test_state.num1 = 1 - # assert test_state.dirty_vars == {"num1", "sum"} - assert test_state.dirty_vars == {"num1"} + assert test_state.dirty_vars == {"num1", "sum"} # Setting another var should mark it as dirty. test_state.num2 = 2 - # assert test_state.dirty_vars == {"num1", "num2", "sum"} - assert test_state.dirty_vars == {"num1", "num2"} + assert test_state.dirty_vars == {"num1", "num2", "sum"} # Cleaning the state should remove all dirty vars. test_state.clean() @@ -746,7 +745,7 @@ class InterdependentState(State): v1: int = 0 _v2: int = 1 - @ComputedVar + @pc.cached_var def v1x2(self) -> int: """Depends on var v1. @@ -755,7 +754,7 @@ class InterdependentState(State): """ return self.v1 * 2 - @ComputedVar + @pc.cached_var def v2x2(self) -> int: """Depends on backend var _v2. @@ -764,7 +763,7 @@ class InterdependentState(State): """ return self._v2 * 2 - @ComputedVar + @pc.cached_var def v1x2x2(self) -> int: """Depends on ComputedVar v1x2. @@ -786,43 +785,43 @@ def interdependent_state() -> State: return s -# def test_not_dirty_computed_var_from_var(interdependent_state): -# """Set Var that no ComputedVar depends on, expect no recalculation. +def test_not_dirty_computed_var_from_var(interdependent_state): + """Set Var that no ComputedVar depends on, expect no recalculation. -# Args: -# interdependent_state: A state with varying Var dependencies. -# """ -# interdependent_state.x = 5 -# assert interdependent_state.get_delta() == { -# interdependent_state.get_full_name(): {"x": 5}, -# } + Args: + interdependent_state: A state with varying Var dependencies. + """ + interdependent_state.x = 5 + assert interdependent_state.get_delta() == { + interdependent_state.get_full_name(): {"x": 5}, + } -# def test_dirty_computed_var_from_var(interdependent_state): -# """Set Var that ComputedVar depends on, expect recalculation. +def test_dirty_computed_var_from_var(interdependent_state): + """Set Var that ComputedVar depends on, expect recalculation. -# The other ComputedVar depends on the changed ComputedVar and should also be -# recalculated. No other ComputedVars should be recalculated. + The other ComputedVar depends on the changed ComputedVar and should also be + recalculated. No other ComputedVars should be recalculated. -# Args: -# interdependent_state: A state with varying Var dependencies. -# """ -# interdependent_state.v1 = 1 -# assert interdependent_state.get_delta() == { -# interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, -# } + Args: + interdependent_state: A state with varying Var dependencies. + """ + interdependent_state.v1 = 1 + assert interdependent_state.get_delta() == { + interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, + } -# def test_dirty_computed_var_from_backend_var(interdependent_state): -# """Set backend var that ComputedVar depends on, expect recalculation. +def test_dirty_computed_var_from_backend_var(interdependent_state): + """Set backend var that ComputedVar depends on, expect recalculation. -# Args: -# interdependent_state: A state with varying Var dependencies. -# """ -# interdependent_state._v2 = 2 -# assert interdependent_state.get_delta() == { -# interdependent_state.get_full_name(): {"v2x2": 4}, -# } + Args: + interdependent_state: A state with varying Var dependencies. + """ + interdependent_state._v2 = 2 + assert interdependent_state.get_delta() == { + interdependent_state.get_full_name(): {"v2x2": 4}, + } def test_per_state_backend_var(interdependent_state): @@ -932,7 +931,7 @@ def test_computed_var_cached(): class ComputedState(State): v: int = 0 - @ComputedVar + @pc.cached_var def comp_v(self) -> int: nonlocal comp_v_calls comp_v_calls += 1 @@ -949,3 +948,84 @@ def test_computed_var_cached(): assert comp_v_calls == 1 assert cs.comp_v == 1 assert comp_v_calls == 2 + + +def test_computed_var_cached_depends_on_non_cached(): + """Test that a cached_var is recalculated if it depends on non-cached ComputedVar.""" + + class ComputedState(State): + v: int = 0 + + @pc.var + def no_cache_v(self) -> int: + return self.v + + @pc.cached_var + def dep_v(self) -> int: + return self.no_cache_v + + @pc.cached_var + def comp_v(self) -> int: + return self.v + + cs = ComputedState() + assert cs.dirty_vars == set() + assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}} + cs.clean() + assert cs.dirty_vars == set() + assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}} + cs.clean() + assert cs.dirty_vars == set() + cs.v = 1 + assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"} + assert cs.get_delta() == { + cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1} + } + cs.clean() + assert cs.dirty_vars == set() + assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}} + cs.clean() + assert cs.dirty_vars == set() + assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}} + cs.clean() + assert cs.dirty_vars == set() + + +def test_computed_var_depends_on_parent_non_cached(): + """Child state cached_var that depends on parent state un cached var is always recalculated.""" + counter = 0 + + class ParentState(State): + @pc.var + def no_cache_v(self) -> int: + nonlocal counter + counter += 1 + return counter + + class ChildState(ParentState): + @pc.cached_var + def dep_v(self) -> int: + return self.no_cache_v + + ps = ParentState() + cs = ps.substates[ChildState.get_name()] + + assert ps.dirty_vars == set() + assert cs.dirty_vars == set() + + assert ps.dict() == { + cs.get_name(): {"dep_v": 2}, + "no_cache_v": 1, + IS_HYDRATED: False, + } + assert ps.dict() == { + cs.get_name(): {"dep_v": 4}, + "no_cache_v": 3, + IS_HYDRATED: False, + } + assert ps.dict() == { + cs.get_name(): {"dep_v": 6}, + "no_cache_v": 5, + IS_HYDRATED: False, + } + assert counter == 6