@pc.cached_var: explicit opt-in for ComputedVar tracking (#1000)
This commit is contained in:
parent
6c60295ba1
commit
0491852a45
@ -31,3 +31,4 @@ from .state import ComputedVar as var
|
||||
from .state import State as State
|
||||
from .style import toggle_color_mode as toggle_color_mode
|
||||
from .vars import Var as Var
|
||||
from .vars import cached_var as cached_var
|
||||
|
@ -688,22 +688,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Return the state update.
|
||||
return StateUpdate(delta=delta, events=events)
|
||||
|
||||
def _always_dirty_computed_vars(self) -> Set[str]:
|
||||
"""The set of ComputedVars that always need to be recalculated.
|
||||
|
||||
Returns:
|
||||
Set of all ComputedVar in this state where cache=False
|
||||
"""
|
||||
return set(
|
||||
cvar_name
|
||||
for cvar_name, cvar in self.computed_vars.items()
|
||||
if not cvar.cache
|
||||
)
|
||||
|
||||
def _mark_dirty_computed_vars(self) -> None:
|
||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||
# Mark all ComputedVars as dirty.
|
||||
for cvar in self.computed_vars.values():
|
||||
cvar.mark_dirty(instance=self)
|
||||
|
||||
# TODO: Uncomment the actual implementation below.
|
||||
# dirty_vars = self.dirty_vars
|
||||
# while dirty_vars:
|
||||
# calc_vars, dirty_vars = dirty_vars, set()
|
||||
# for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||
# self.dirty_vars.add(cvar)
|
||||
# dirty_vars.add(cvar)
|
||||
# actual_var = self.computed_vars.get(cvar)
|
||||
# if actual_var:
|
||||
# actual_var.mark_dirty(instance=self)
|
||||
dirty_vars = self.dirty_vars
|
||||
while dirty_vars:
|
||||
calc_vars, dirty_vars = dirty_vars, set()
|
||||
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||
self.dirty_vars.add(cvar)
|
||||
dirty_vars.add(cvar)
|
||||
actual_var = self.computed_vars.get(cvar)
|
||||
if actual_var:
|
||||
actual_var.mark_dirty(instance=self)
|
||||
|
||||
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||
@ -714,13 +721,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Returns:
|
||||
Set of computed vars to include in the delta.
|
||||
"""
|
||||
return set(self.computed_vars)
|
||||
# TODO: Uncomment the actual implementation below.
|
||||
# return set(
|
||||
# cvar
|
||||
# for dirty_var in from_vars or self.dirty_vars
|
||||
# for cvar in self.computed_var_dependencies[dirty_var]
|
||||
# )
|
||||
return set(
|
||||
cvar
|
||||
for dirty_var in from_vars or self.dirty_vars
|
||||
for cvar in self.computed_var_dependencies[dirty_var]
|
||||
)
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
@ -730,11 +735,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
delta = {}
|
||||
|
||||
# Return the dirty vars and dependent computed vars
|
||||
self._mark_dirty_computed_vars()
|
||||
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
||||
self._dirty_computed_vars()
|
||||
# Apply dirty variables down into substates
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||
self.mark_dirty()
|
||||
|
||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||
# and always dirty computed vars (cache=False)
|
||||
delta_vars = (
|
||||
self.dirty_vars.intersection(self.base_vars)
|
||||
.union(self._dirty_computed_vars())
|
||||
.union(self._always_dirty_computed_vars())
|
||||
)
|
||||
|
||||
subdelta = {
|
||||
prop: getattr(self, prop)
|
||||
for prop in delta_vars
|
||||
@ -797,6 +809,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Returns:
|
||||
The object as a dictionary.
|
||||
"""
|
||||
if include_computed:
|
||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
||||
# trigger recalculation of dependent vars
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||
self.mark_dirty()
|
||||
|
||||
base_vars = {
|
||||
prop_name: self.get_value(getattr(self, prop_name))
|
||||
for prop_name in self.base_vars
|
||||
|
@ -801,6 +801,9 @@ class BaseVar(Var, Base):
|
||||
class ComputedVar(Var, property):
|
||||
"""A field with computed getters."""
|
||||
|
||||
# Whether to track dependencies and cache computed values
|
||||
cache: bool = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the name of the var.
|
||||
@ -832,7 +835,7 @@ class ComputedVar(Var, property):
|
||||
Returns:
|
||||
The value of the var for the given instance.
|
||||
"""
|
||||
if instance is None:
|
||||
if instance is None or not self.cache:
|
||||
return super().__get__(instance, owner)
|
||||
|
||||
# handle caching
|
||||
@ -906,6 +909,23 @@ class ComputedVar(Var, property):
|
||||
return Any
|
||||
|
||||
|
||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
|
||||
"""A field with computed getter that tracks other state dependencies.
|
||||
|
||||
The cached_var will only be recalculated when other state vars that it
|
||||
depends on are modified.
|
||||
|
||||
Args:
|
||||
fget: the function that calculates the variable value.
|
||||
|
||||
Returns:
|
||||
ComputedVar that is recomputed when dependencies change.
|
||||
"""
|
||||
cvar = ComputedVar(fget=fget)
|
||||
cvar.cache = True
|
||||
return cvar
|
||||
|
||||
|
||||
class PCList(list):
|
||||
"""A custom list that pynecone can detect its mutation."""
|
||||
|
||||
|
@ -3,6 +3,7 @@ from typing import Dict, List
|
||||
import pytest
|
||||
from plotly.graph_objects import Figure
|
||||
|
||||
import pynecone as pc
|
||||
from pynecone.base import Base
|
||||
from pynecone.constants import IS_HYDRATED, RouteVar
|
||||
from pynecone.event import Event, EventHandler
|
||||
@ -484,13 +485,11 @@ def test_set_dirty_var(test_state):
|
||||
|
||||
# Setting a var should mark it as dirty.
|
||||
test_state.num1 = 1
|
||||
# assert test_state.dirty_vars == {"num1", "sum"}
|
||||
assert test_state.dirty_vars == {"num1"}
|
||||
assert test_state.dirty_vars == {"num1", "sum"}
|
||||
|
||||
# Setting another var should mark it as dirty.
|
||||
test_state.num2 = 2
|
||||
# assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||
assert test_state.dirty_vars == {"num1", "num2"}
|
||||
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||
|
||||
# Cleaning the state should remove all dirty vars.
|
||||
test_state.clean()
|
||||
@ -746,7 +745,7 @@ class InterdependentState(State):
|
||||
v1: int = 0
|
||||
_v2: int = 1
|
||||
|
||||
@ComputedVar
|
||||
@pc.cached_var
|
||||
def v1x2(self) -> int:
|
||||
"""Depends on var v1.
|
||||
|
||||
@ -755,7 +754,7 @@ class InterdependentState(State):
|
||||
"""
|
||||
return self.v1 * 2
|
||||
|
||||
@ComputedVar
|
||||
@pc.cached_var
|
||||
def v2x2(self) -> int:
|
||||
"""Depends on backend var _v2.
|
||||
|
||||
@ -764,7 +763,7 @@ class InterdependentState(State):
|
||||
"""
|
||||
return self._v2 * 2
|
||||
|
||||
@ComputedVar
|
||||
@pc.cached_var
|
||||
def v1x2x2(self) -> int:
|
||||
"""Depends on ComputedVar v1x2.
|
||||
|
||||
@ -786,43 +785,43 @@ def interdependent_state() -> State:
|
||||
return s
|
||||
|
||||
|
||||
# def test_not_dirty_computed_var_from_var(interdependent_state):
|
||||
# """Set Var that no ComputedVar depends on, expect no recalculation.
|
||||
def test_not_dirty_computed_var_from_var(interdependent_state):
|
||||
"""Set Var that no ComputedVar depends on, expect no recalculation.
|
||||
|
||||
# Args:
|
||||
# interdependent_state: A state with varying Var dependencies.
|
||||
# """
|
||||
# interdependent_state.x = 5
|
||||
# assert interdependent_state.get_delta() == {
|
||||
# interdependent_state.get_full_name(): {"x": 5},
|
||||
# }
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state.x = 5
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"x": 5},
|
||||
}
|
||||
|
||||
|
||||
# def test_dirty_computed_var_from_var(interdependent_state):
|
||||
# """Set Var that ComputedVar depends on, expect recalculation.
|
||||
def test_dirty_computed_var_from_var(interdependent_state):
|
||||
"""Set Var that ComputedVar depends on, expect recalculation.
|
||||
|
||||
# The other ComputedVar depends on the changed ComputedVar and should also be
|
||||
# recalculated. No other ComputedVars should be recalculated.
|
||||
The other ComputedVar depends on the changed ComputedVar and should also be
|
||||
recalculated. No other ComputedVars should be recalculated.
|
||||
|
||||
# Args:
|
||||
# interdependent_state: A state with varying Var dependencies.
|
||||
# """
|
||||
# interdependent_state.v1 = 1
|
||||
# assert interdependent_state.get_delta() == {
|
||||
# interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||
# }
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state.v1 = 1
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||
}
|
||||
|
||||
|
||||
# def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||
# """Set backend var that ComputedVar depends on, expect recalculation.
|
||||
def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||
"""Set backend var that ComputedVar depends on, expect recalculation.
|
||||
|
||||
# Args:
|
||||
# interdependent_state: A state with varying Var dependencies.
|
||||
# """
|
||||
# interdependent_state._v2 = 2
|
||||
# assert interdependent_state.get_delta() == {
|
||||
# interdependent_state.get_full_name(): {"v2x2": 4},
|
||||
# }
|
||||
Args:
|
||||
interdependent_state: A state with varying Var dependencies.
|
||||
"""
|
||||
interdependent_state._v2 = 2
|
||||
assert interdependent_state.get_delta() == {
|
||||
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||
}
|
||||
|
||||
|
||||
def test_per_state_backend_var(interdependent_state):
|
||||
@ -932,7 +931,7 @@ def test_computed_var_cached():
|
||||
class ComputedState(State):
|
||||
v: int = 0
|
||||
|
||||
@ComputedVar
|
||||
@pc.cached_var
|
||||
def comp_v(self) -> int:
|
||||
nonlocal comp_v_calls
|
||||
comp_v_calls += 1
|
||||
@ -949,3 +948,84 @@ def test_computed_var_cached():
|
||||
assert comp_v_calls == 1
|
||||
assert cs.comp_v == 1
|
||||
assert comp_v_calls == 2
|
||||
|
||||
|
||||
def test_computed_var_cached_depends_on_non_cached():
|
||||
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
|
||||
|
||||
class ComputedState(State):
|
||||
v: int = 0
|
||||
|
||||
@pc.var
|
||||
def no_cache_v(self) -> int:
|
||||
return self.v
|
||||
|
||||
@pc.cached_var
|
||||
def dep_v(self) -> int:
|
||||
return self.no_cache_v
|
||||
|
||||
@pc.cached_var
|
||||
def comp_v(self) -> int:
|
||||
return self.v
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||
cs.clean()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||
cs.clean()
|
||||
assert cs.dirty_vars == set()
|
||||
cs.v = 1
|
||||
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
|
||||
assert cs.get_delta() == {
|
||||
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
|
||||
}
|
||||
cs.clean()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||
cs.clean()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||
cs.clean()
|
||||
assert cs.dirty_vars == set()
|
||||
|
||||
|
||||
def test_computed_var_depends_on_parent_non_cached():
|
||||
"""Child state cached_var that depends on parent state un cached var is always recalculated."""
|
||||
counter = 0
|
||||
|
||||
class ParentState(State):
|
||||
@pc.var
|
||||
def no_cache_v(self) -> int:
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
class ChildState(ParentState):
|
||||
@pc.cached_var
|
||||
def dep_v(self) -> int:
|
||||
return self.no_cache_v
|
||||
|
||||
ps = ParentState()
|
||||
cs = ps.substates[ChildState.get_name()]
|
||||
|
||||
assert ps.dirty_vars == set()
|
||||
assert cs.dirty_vars == set()
|
||||
|
||||
assert ps.dict() == {
|
||||
cs.get_name(): {"dep_v": 2},
|
||||
"no_cache_v": 1,
|
||||
IS_HYDRATED: False,
|
||||
}
|
||||
assert ps.dict() == {
|
||||
cs.get_name(): {"dep_v": 4},
|
||||
"no_cache_v": 3,
|
||||
IS_HYDRATED: False,
|
||||
}
|
||||
assert ps.dict() == {
|
||||
cs.get_name(): {"dep_v": 6},
|
||||
"no_cache_v": 5,
|
||||
IS_HYDRATED: False,
|
||||
}
|
||||
assert counter == 6
|
||||
|
Loading…
Reference in New Issue
Block a user