@pc.cached_var: explicit opt-in for ComputedVar tracking (#1000)

This commit is contained in:
Masen Furer 2023-05-11 17:47:54 -07:00 committed by GitHub
parent 6c60295ba1
commit 0491852a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 182 additions and 63 deletions

View File

@ -31,3 +31,4 @@ from .state import ComputedVar as var
from .state import State as State from .state import State as State
from .style import toggle_color_mode as toggle_color_mode from .style import toggle_color_mode as toggle_color_mode
from .vars import Var as Var from .vars import Var as Var
from .vars import cached_var as cached_var

View File

@ -688,22 +688,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Return the state update. # Return the state update.
return StateUpdate(delta=delta, events=events) return StateUpdate(delta=delta, events=events)
def _always_dirty_computed_vars(self) -> Set[str]:
"""The set of ComputedVars that always need to be recalculated.
Returns:
Set of all ComputedVar in this state where cache=False
"""
return set(
cvar_name
for cvar_name, cvar in self.computed_vars.items()
if not cvar.cache
)
def _mark_dirty_computed_vars(self) -> None: def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars.""" """Mark ComputedVars that need to be recalculated based on dirty_vars."""
# Mark all ComputedVars as dirty. dirty_vars = self.dirty_vars
for cvar in self.computed_vars.values(): while dirty_vars:
cvar.mark_dirty(instance=self) calc_vars, dirty_vars = dirty_vars, set()
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
# TODO: Uncomment the actual implementation below. self.dirty_vars.add(cvar)
# dirty_vars = self.dirty_vars dirty_vars.add(cvar)
# while dirty_vars: actual_var = self.computed_vars.get(cvar)
# calc_vars, dirty_vars = dirty_vars, set() if actual_var:
# for cvar in self._dirty_computed_vars(from_vars=calc_vars): actual_var.mark_dirty(instance=self)
# self.dirty_vars.add(cvar)
# dirty_vars.add(cvar)
# actual_var = self.computed_vars.get(cvar)
# if actual_var:
# actual_var.mark_dirty(instance=self)
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]: def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars. """Determine ComputedVars that need to be recalculated based on the given vars.
@ -714,13 +721,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
Set of computed vars to include in the delta. Set of computed vars to include in the delta.
""" """
return set(self.computed_vars) return set(
# TODO: Uncomment the actual implementation below. cvar
# return set( for dirty_var in from_vars or self.dirty_vars
# cvar for cvar in self.computed_var_dependencies[dirty_var]
# for dirty_var in from_vars or self.dirty_vars )
# for cvar in self.computed_var_dependencies[dirty_var]
# )
def get_delta(self) -> Delta: def get_delta(self) -> Delta:
"""Get the delta for the state. """Get the delta for the state.
@ -730,11 +735,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
""" """
delta = {} delta = {}
# Return the dirty vars and dependent computed vars # Apply dirty variables down into substates
self._mark_dirty_computed_vars() self.dirty_vars.update(self._always_dirty_computed_vars())
delta_vars = self.dirty_vars.intersection(self.base_vars).union( self.mark_dirty()
self._dirty_computed_vars()
# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars())
) )
subdelta = { subdelta = {
prop: getattr(self, prop) prop: getattr(self, prop)
for prop in delta_vars for prop in delta_vars
@ -797,6 +809,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
The object as a dictionary. The object as a dictionary.
""" """
if include_computed:
# Apply dirty variables down into substates to allow never-cached ComputedVar to
# trigger recalculation of dependent vars
self.dirty_vars.update(self._always_dirty_computed_vars())
self.mark_dirty()
base_vars = { base_vars = {
prop_name: self.get_value(getattr(self, prop_name)) prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.base_vars for prop_name in self.base_vars

View File

@ -801,6 +801,9 @@ class BaseVar(Var, Base):
class ComputedVar(Var, property): class ComputedVar(Var, property):
"""A field with computed getters.""" """A field with computed getters."""
# Whether to track dependencies and cache computed values
cache: bool = False
@property @property
def name(self) -> str: def name(self) -> str:
"""Get the name of the var. """Get the name of the var.
@ -832,7 +835,7 @@ class ComputedVar(Var, property):
Returns: Returns:
The value of the var for the given instance. The value of the var for the given instance.
""" """
if instance is None: if instance is None or not self.cache:
return super().__get__(instance, owner) return super().__get__(instance, owner)
# handle caching # handle caching
@ -906,6 +909,23 @@ class ComputedVar(Var, property):
return Any return Any
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
"""A field with computed getter that tracks other state dependencies.
The cached_var will only be recalculated when other state vars that it
depends on are modified.
Args:
fget: the function that calculates the variable value.
Returns:
ComputedVar that is recomputed when dependencies change.
"""
cvar = ComputedVar(fget=fget)
cvar.cache = True
return cvar
class PCList(list): class PCList(list):
"""A custom list that pynecone can detect its mutation.""" """A custom list that pynecone can detect its mutation."""

View File

@ -3,6 +3,7 @@ from typing import Dict, List
import pytest import pytest
from plotly.graph_objects import Figure from plotly.graph_objects import Figure
import pynecone as pc
from pynecone.base import Base from pynecone.base import Base
from pynecone.constants import IS_HYDRATED, RouteVar from pynecone.constants import IS_HYDRATED, RouteVar
from pynecone.event import Event, EventHandler from pynecone.event import Event, EventHandler
@ -484,13 +485,11 @@ def test_set_dirty_var(test_state):
# Setting a var should mark it as dirty. # Setting a var should mark it as dirty.
test_state.num1 = 1 test_state.num1 = 1
# assert test_state.dirty_vars == {"num1", "sum"} assert test_state.dirty_vars == {"num1", "sum"}
assert test_state.dirty_vars == {"num1"}
# Setting another var should mark it as dirty. # Setting another var should mark it as dirty.
test_state.num2 = 2 test_state.num2 = 2
# assert test_state.dirty_vars == {"num1", "num2", "sum"} assert test_state.dirty_vars == {"num1", "num2", "sum"}
assert test_state.dirty_vars == {"num1", "num2"}
# Cleaning the state should remove all dirty vars. # Cleaning the state should remove all dirty vars.
test_state.clean() test_state.clean()
@ -746,7 +745,7 @@ class InterdependentState(State):
v1: int = 0 v1: int = 0
_v2: int = 1 _v2: int = 1
@ComputedVar @pc.cached_var
def v1x2(self) -> int: def v1x2(self) -> int:
"""Depends on var v1. """Depends on var v1.
@ -755,7 +754,7 @@ class InterdependentState(State):
""" """
return self.v1 * 2 return self.v1 * 2
@ComputedVar @pc.cached_var
def v2x2(self) -> int: def v2x2(self) -> int:
"""Depends on backend var _v2. """Depends on backend var _v2.
@ -764,7 +763,7 @@ class InterdependentState(State):
""" """
return self._v2 * 2 return self._v2 * 2
@ComputedVar @pc.cached_var
def v1x2x2(self) -> int: def v1x2x2(self) -> int:
"""Depends on ComputedVar v1x2. """Depends on ComputedVar v1x2.
@ -786,43 +785,43 @@ def interdependent_state() -> State:
return s return s
# def test_not_dirty_computed_var_from_var(interdependent_state): def test_not_dirty_computed_var_from_var(interdependent_state):
# """Set Var that no ComputedVar depends on, expect no recalculation. """Set Var that no ComputedVar depends on, expect no recalculation.
# Args: Args:
# interdependent_state: A state with varying Var dependencies. interdependent_state: A state with varying Var dependencies.
# """ """
# interdependent_state.x = 5 interdependent_state.x = 5
# assert interdependent_state.get_delta() == { assert interdependent_state.get_delta() == {
# interdependent_state.get_full_name(): {"x": 5}, interdependent_state.get_full_name(): {"x": 5},
# } }
# def test_dirty_computed_var_from_var(interdependent_state): def test_dirty_computed_var_from_var(interdependent_state):
# """Set Var that ComputedVar depends on, expect recalculation. """Set Var that ComputedVar depends on, expect recalculation.
# The other ComputedVar depends on the changed ComputedVar and should also be The other ComputedVar depends on the changed ComputedVar and should also be
# recalculated. No other ComputedVars should be recalculated. recalculated. No other ComputedVars should be recalculated.
# Args: Args:
# interdependent_state: A state with varying Var dependencies. interdependent_state: A state with varying Var dependencies.
# """ """
# interdependent_state.v1 = 1 interdependent_state.v1 = 1
# assert interdependent_state.get_delta() == { assert interdependent_state.get_delta() == {
# interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
# } }
# def test_dirty_computed_var_from_backend_var(interdependent_state): def test_dirty_computed_var_from_backend_var(interdependent_state):
# """Set backend var that ComputedVar depends on, expect recalculation. """Set backend var that ComputedVar depends on, expect recalculation.
# Args: Args:
# interdependent_state: A state with varying Var dependencies. interdependent_state: A state with varying Var dependencies.
# """ """
# interdependent_state._v2 = 2 interdependent_state._v2 = 2
# assert interdependent_state.get_delta() == { assert interdependent_state.get_delta() == {
# interdependent_state.get_full_name(): {"v2x2": 4}, interdependent_state.get_full_name(): {"v2x2": 4},
# } }
def test_per_state_backend_var(interdependent_state): def test_per_state_backend_var(interdependent_state):
@ -932,7 +931,7 @@ def test_computed_var_cached():
class ComputedState(State): class ComputedState(State):
v: int = 0 v: int = 0
@ComputedVar @pc.cached_var
def comp_v(self) -> int: def comp_v(self) -> int:
nonlocal comp_v_calls nonlocal comp_v_calls
comp_v_calls += 1 comp_v_calls += 1
@ -949,3 +948,84 @@ def test_computed_var_cached():
assert comp_v_calls == 1 assert comp_v_calls == 1
assert cs.comp_v == 1 assert cs.comp_v == 1
assert comp_v_calls == 2 assert comp_v_calls == 2
def test_computed_var_cached_depends_on_non_cached():
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
class ComputedState(State):
v: int = 0
@pc.var
def no_cache_v(self) -> int:
return self.v
@pc.cached_var
def dep_v(self) -> int:
return self.no_cache_v
@pc.cached_var
def comp_v(self) -> int:
return self.v
cs = ComputedState()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
cs.clean()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
cs.clean()
assert cs.dirty_vars == set()
cs.v = 1
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
assert cs.get_delta() == {
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
}
cs.clean()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
cs.clean()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
cs.clean()
assert cs.dirty_vars == set()
def test_computed_var_depends_on_parent_non_cached():
"""Child state cached_var that depends on parent state un cached var is always recalculated."""
counter = 0
class ParentState(State):
@pc.var
def no_cache_v(self) -> int:
nonlocal counter
counter += 1
return counter
class ChildState(ParentState):
@pc.cached_var
def dep_v(self) -> int:
return self.no_cache_v
ps = ParentState()
cs = ps.substates[ChildState.get_name()]
assert ps.dirty_vars == set()
assert cs.dirty_vars == set()
assert ps.dict() == {
cs.get_name(): {"dep_v": 2},
"no_cache_v": 1,
IS_HYDRATED: False,
}
assert ps.dict() == {
cs.get_name(): {"dep_v": 4},
"no_cache_v": 3,
IS_HYDRATED: False,
}
assert ps.dict() == {
cs.get_name(): {"dep_v": 6},
"no_cache_v": 5,
IS_HYDRATED: False,
}
assert counter == 6