@pc.cached_var: explicit opt-in for ComputedVar tracking (#1000)
This commit is contained in:
parent
6c60295ba1
commit
0491852a45
@ -31,3 +31,4 @@ from .state import ComputedVar as var
|
|||||||
from .state import State as State
|
from .state import State as State
|
||||||
from .style import toggle_color_mode as toggle_color_mode
|
from .style import toggle_color_mode as toggle_color_mode
|
||||||
from .vars import Var as Var
|
from .vars import Var as Var
|
||||||
|
from .vars import cached_var as cached_var
|
||||||
|
@ -688,22 +688,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Return the state update.
|
# Return the state update.
|
||||||
return StateUpdate(delta=delta, events=events)
|
return StateUpdate(delta=delta, events=events)
|
||||||
|
|
||||||
|
def _always_dirty_computed_vars(self) -> Set[str]:
|
||||||
|
"""The set of ComputedVars that always need to be recalculated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of all ComputedVar in this state where cache=False
|
||||||
|
"""
|
||||||
|
return set(
|
||||||
|
cvar_name
|
||||||
|
for cvar_name, cvar in self.computed_vars.items()
|
||||||
|
if not cvar.cache
|
||||||
|
)
|
||||||
|
|
||||||
def _mark_dirty_computed_vars(self) -> None:
|
def _mark_dirty_computed_vars(self) -> None:
|
||||||
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
|
||||||
# Mark all ComputedVars as dirty.
|
dirty_vars = self.dirty_vars
|
||||||
for cvar in self.computed_vars.values():
|
while dirty_vars:
|
||||||
cvar.mark_dirty(instance=self)
|
calc_vars, dirty_vars = dirty_vars, set()
|
||||||
|
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
||||||
# TODO: Uncomment the actual implementation below.
|
self.dirty_vars.add(cvar)
|
||||||
# dirty_vars = self.dirty_vars
|
dirty_vars.add(cvar)
|
||||||
# while dirty_vars:
|
actual_var = self.computed_vars.get(cvar)
|
||||||
# calc_vars, dirty_vars = dirty_vars, set()
|
if actual_var:
|
||||||
# for cvar in self._dirty_computed_vars(from_vars=calc_vars):
|
actual_var.mark_dirty(instance=self)
|
||||||
# self.dirty_vars.add(cvar)
|
|
||||||
# dirty_vars.add(cvar)
|
|
||||||
# actual_var = self.computed_vars.get(cvar)
|
|
||||||
# if actual_var:
|
|
||||||
# actual_var.mark_dirty(instance=self)
|
|
||||||
|
|
||||||
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]:
|
||||||
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
"""Determine ComputedVars that need to be recalculated based on the given vars.
|
||||||
@ -714,13 +721,11 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Returns:
|
Returns:
|
||||||
Set of computed vars to include in the delta.
|
Set of computed vars to include in the delta.
|
||||||
"""
|
"""
|
||||||
return set(self.computed_vars)
|
return set(
|
||||||
# TODO: Uncomment the actual implementation below.
|
cvar
|
||||||
# return set(
|
for dirty_var in from_vars or self.dirty_vars
|
||||||
# cvar
|
for cvar in self.computed_var_dependencies[dirty_var]
|
||||||
# for dirty_var in from_vars or self.dirty_vars
|
)
|
||||||
# for cvar in self.computed_var_dependencies[dirty_var]
|
|
||||||
# )
|
|
||||||
|
|
||||||
def get_delta(self) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
@ -730,11 +735,18 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"""
|
"""
|
||||||
delta = {}
|
delta = {}
|
||||||
|
|
||||||
# Return the dirty vars and dependent computed vars
|
# Apply dirty variables down into substates
|
||||||
self._mark_dirty_computed_vars()
|
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||||
delta_vars = self.dirty_vars.intersection(self.base_vars).union(
|
self.mark_dirty()
|
||||||
self._dirty_computed_vars()
|
|
||||||
|
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||||
|
# and always dirty computed vars (cache=False)
|
||||||
|
delta_vars = (
|
||||||
|
self.dirty_vars.intersection(self.base_vars)
|
||||||
|
.union(self._dirty_computed_vars())
|
||||||
|
.union(self._always_dirty_computed_vars())
|
||||||
)
|
)
|
||||||
|
|
||||||
subdelta = {
|
subdelta = {
|
||||||
prop: getattr(self, prop)
|
prop: getattr(self, prop)
|
||||||
for prop in delta_vars
|
for prop in delta_vars
|
||||||
@ -797,6 +809,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Returns:
|
Returns:
|
||||||
The object as a dictionary.
|
The object as a dictionary.
|
||||||
"""
|
"""
|
||||||
|
if include_computed:
|
||||||
|
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
||||||
|
# trigger recalculation of dependent vars
|
||||||
|
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||||
|
self.mark_dirty()
|
||||||
|
|
||||||
base_vars = {
|
base_vars = {
|
||||||
prop_name: self.get_value(getattr(self, prop_name))
|
prop_name: self.get_value(getattr(self, prop_name))
|
||||||
for prop_name in self.base_vars
|
for prop_name in self.base_vars
|
||||||
|
@ -801,6 +801,9 @@ class BaseVar(Var, Base):
|
|||||||
class ComputedVar(Var, property):
|
class ComputedVar(Var, property):
|
||||||
"""A field with computed getters."""
|
"""A field with computed getters."""
|
||||||
|
|
||||||
|
# Whether to track dependencies and cache computed values
|
||||||
|
cache: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Get the name of the var.
|
"""Get the name of the var.
|
||||||
@ -832,7 +835,7 @@ class ComputedVar(Var, property):
|
|||||||
Returns:
|
Returns:
|
||||||
The value of the var for the given instance.
|
The value of the var for the given instance.
|
||||||
"""
|
"""
|
||||||
if instance is None:
|
if instance is None or not self.cache:
|
||||||
return super().__get__(instance, owner)
|
return super().__get__(instance, owner)
|
||||||
|
|
||||||
# handle caching
|
# handle caching
|
||||||
@ -906,6 +909,23 @@ class ComputedVar(Var, property):
|
|||||||
return Any
|
return Any
|
||||||
|
|
||||||
|
|
||||||
|
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
|
||||||
|
"""A field with computed getter that tracks other state dependencies.
|
||||||
|
|
||||||
|
The cached_var will only be recalculated when other state vars that it
|
||||||
|
depends on are modified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fget: the function that calculates the variable value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ComputedVar that is recomputed when dependencies change.
|
||||||
|
"""
|
||||||
|
cvar = ComputedVar(fget=fget)
|
||||||
|
cvar.cache = True
|
||||||
|
return cvar
|
||||||
|
|
||||||
|
|
||||||
class PCList(list):
|
class PCList(list):
|
||||||
"""A custom list that pynecone can detect its mutation."""
|
"""A custom list that pynecone can detect its mutation."""
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ from typing import Dict, List
|
|||||||
import pytest
|
import pytest
|
||||||
from plotly.graph_objects import Figure
|
from plotly.graph_objects import Figure
|
||||||
|
|
||||||
|
import pynecone as pc
|
||||||
from pynecone.base import Base
|
from pynecone.base import Base
|
||||||
from pynecone.constants import IS_HYDRATED, RouteVar
|
from pynecone.constants import IS_HYDRATED, RouteVar
|
||||||
from pynecone.event import Event, EventHandler
|
from pynecone.event import Event, EventHandler
|
||||||
@ -484,13 +485,11 @@ def test_set_dirty_var(test_state):
|
|||||||
|
|
||||||
# Setting a var should mark it as dirty.
|
# Setting a var should mark it as dirty.
|
||||||
test_state.num1 = 1
|
test_state.num1 = 1
|
||||||
# assert test_state.dirty_vars == {"num1", "sum"}
|
assert test_state.dirty_vars == {"num1", "sum"}
|
||||||
assert test_state.dirty_vars == {"num1"}
|
|
||||||
|
|
||||||
# Setting another var should mark it as dirty.
|
# Setting another var should mark it as dirty.
|
||||||
test_state.num2 = 2
|
test_state.num2 = 2
|
||||||
# assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||||
assert test_state.dirty_vars == {"num1", "num2"}
|
|
||||||
|
|
||||||
# Cleaning the state should remove all dirty vars.
|
# Cleaning the state should remove all dirty vars.
|
||||||
test_state.clean()
|
test_state.clean()
|
||||||
@ -746,7 +745,7 @@ class InterdependentState(State):
|
|||||||
v1: int = 0
|
v1: int = 0
|
||||||
_v2: int = 1
|
_v2: int = 1
|
||||||
|
|
||||||
@ComputedVar
|
@pc.cached_var
|
||||||
def v1x2(self) -> int:
|
def v1x2(self) -> int:
|
||||||
"""Depends on var v1.
|
"""Depends on var v1.
|
||||||
|
|
||||||
@ -755,7 +754,7 @@ class InterdependentState(State):
|
|||||||
"""
|
"""
|
||||||
return self.v1 * 2
|
return self.v1 * 2
|
||||||
|
|
||||||
@ComputedVar
|
@pc.cached_var
|
||||||
def v2x2(self) -> int:
|
def v2x2(self) -> int:
|
||||||
"""Depends on backend var _v2.
|
"""Depends on backend var _v2.
|
||||||
|
|
||||||
@ -764,7 +763,7 @@ class InterdependentState(State):
|
|||||||
"""
|
"""
|
||||||
return self._v2 * 2
|
return self._v2 * 2
|
||||||
|
|
||||||
@ComputedVar
|
@pc.cached_var
|
||||||
def v1x2x2(self) -> int:
|
def v1x2x2(self) -> int:
|
||||||
"""Depends on ComputedVar v1x2.
|
"""Depends on ComputedVar v1x2.
|
||||||
|
|
||||||
@ -786,43 +785,43 @@ def interdependent_state() -> State:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# def test_not_dirty_computed_var_from_var(interdependent_state):
|
def test_not_dirty_computed_var_from_var(interdependent_state):
|
||||||
# """Set Var that no ComputedVar depends on, expect no recalculation.
|
"""Set Var that no ComputedVar depends on, expect no recalculation.
|
||||||
|
|
||||||
# Args:
|
Args:
|
||||||
# interdependent_state: A state with varying Var dependencies.
|
interdependent_state: A state with varying Var dependencies.
|
||||||
# """
|
"""
|
||||||
# interdependent_state.x = 5
|
interdependent_state.x = 5
|
||||||
# assert interdependent_state.get_delta() == {
|
assert interdependent_state.get_delta() == {
|
||||||
# interdependent_state.get_full_name(): {"x": 5},
|
interdependent_state.get_full_name(): {"x": 5},
|
||||||
# }
|
}
|
||||||
|
|
||||||
|
|
||||||
# def test_dirty_computed_var_from_var(interdependent_state):
|
def test_dirty_computed_var_from_var(interdependent_state):
|
||||||
# """Set Var that ComputedVar depends on, expect recalculation.
|
"""Set Var that ComputedVar depends on, expect recalculation.
|
||||||
|
|
||||||
# The other ComputedVar depends on the changed ComputedVar and should also be
|
The other ComputedVar depends on the changed ComputedVar and should also be
|
||||||
# recalculated. No other ComputedVars should be recalculated.
|
recalculated. No other ComputedVars should be recalculated.
|
||||||
|
|
||||||
# Args:
|
Args:
|
||||||
# interdependent_state: A state with varying Var dependencies.
|
interdependent_state: A state with varying Var dependencies.
|
||||||
# """
|
"""
|
||||||
# interdependent_state.v1 = 1
|
interdependent_state.v1 = 1
|
||||||
# assert interdependent_state.get_delta() == {
|
assert interdependent_state.get_delta() == {
|
||||||
# interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4},
|
||||||
# }
|
}
|
||||||
|
|
||||||
|
|
||||||
# def test_dirty_computed_var_from_backend_var(interdependent_state):
|
def test_dirty_computed_var_from_backend_var(interdependent_state):
|
||||||
# """Set backend var that ComputedVar depends on, expect recalculation.
|
"""Set backend var that ComputedVar depends on, expect recalculation.
|
||||||
|
|
||||||
# Args:
|
Args:
|
||||||
# interdependent_state: A state with varying Var dependencies.
|
interdependent_state: A state with varying Var dependencies.
|
||||||
# """
|
"""
|
||||||
# interdependent_state._v2 = 2
|
interdependent_state._v2 = 2
|
||||||
# assert interdependent_state.get_delta() == {
|
assert interdependent_state.get_delta() == {
|
||||||
# interdependent_state.get_full_name(): {"v2x2": 4},
|
interdependent_state.get_full_name(): {"v2x2": 4},
|
||||||
# }
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_per_state_backend_var(interdependent_state):
|
def test_per_state_backend_var(interdependent_state):
|
||||||
@ -932,7 +931,7 @@ def test_computed_var_cached():
|
|||||||
class ComputedState(State):
|
class ComputedState(State):
|
||||||
v: int = 0
|
v: int = 0
|
||||||
|
|
||||||
@ComputedVar
|
@pc.cached_var
|
||||||
def comp_v(self) -> int:
|
def comp_v(self) -> int:
|
||||||
nonlocal comp_v_calls
|
nonlocal comp_v_calls
|
||||||
comp_v_calls += 1
|
comp_v_calls += 1
|
||||||
@ -949,3 +948,84 @@ def test_computed_var_cached():
|
|||||||
assert comp_v_calls == 1
|
assert comp_v_calls == 1
|
||||||
assert cs.comp_v == 1
|
assert cs.comp_v == 1
|
||||||
assert comp_v_calls == 2
|
assert comp_v_calls == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_computed_var_cached_depends_on_non_cached():
|
||||||
|
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
|
||||||
|
|
||||||
|
class ComputedState(State):
|
||||||
|
v: int = 0
|
||||||
|
|
||||||
|
@pc.var
|
||||||
|
def no_cache_v(self) -> int:
|
||||||
|
return self.v
|
||||||
|
|
||||||
|
@pc.cached_var
|
||||||
|
def dep_v(self) -> int:
|
||||||
|
return self.no_cache_v
|
||||||
|
|
||||||
|
@pc.cached_var
|
||||||
|
def comp_v(self) -> int:
|
||||||
|
return self.v
|
||||||
|
|
||||||
|
cs = ComputedState()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||||
|
cs.clean()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||||
|
cs.clean()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
cs.v = 1
|
||||||
|
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
|
||||||
|
assert cs.get_delta() == {
|
||||||
|
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
|
||||||
|
}
|
||||||
|
cs.clean()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||||
|
cs.clean()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||||
|
cs.clean()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_computed_var_depends_on_parent_non_cached():
|
||||||
|
"""Child state cached_var that depends on parent state un cached var is always recalculated."""
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
class ParentState(State):
|
||||||
|
@pc.var
|
||||||
|
def no_cache_v(self) -> int:
|
||||||
|
nonlocal counter
|
||||||
|
counter += 1
|
||||||
|
return counter
|
||||||
|
|
||||||
|
class ChildState(ParentState):
|
||||||
|
@pc.cached_var
|
||||||
|
def dep_v(self) -> int:
|
||||||
|
return self.no_cache_v
|
||||||
|
|
||||||
|
ps = ParentState()
|
||||||
|
cs = ps.substates[ChildState.get_name()]
|
||||||
|
|
||||||
|
assert ps.dirty_vars == set()
|
||||||
|
assert cs.dirty_vars == set()
|
||||||
|
|
||||||
|
assert ps.dict() == {
|
||||||
|
cs.get_name(): {"dep_v": 2},
|
||||||
|
"no_cache_v": 1,
|
||||||
|
IS_HYDRATED: False,
|
||||||
|
}
|
||||||
|
assert ps.dict() == {
|
||||||
|
cs.get_name(): {"dep_v": 4},
|
||||||
|
"no_cache_v": 3,
|
||||||
|
IS_HYDRATED: False,
|
||||||
|
}
|
||||||
|
assert ps.dict() == {
|
||||||
|
cs.get_name(): {"dep_v": 6},
|
||||||
|
"no_cache_v": 5,
|
||||||
|
IS_HYDRATED: False,
|
||||||
|
}
|
||||||
|
assert counter == 6
|
||||||
|
Loading…
Reference in New Issue
Block a user