diff --git a/pynecone/state.py b/pynecone/state.py index d49d89257..e022c764b 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -74,12 +74,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Mapping of var name to set of computed variables that depend on it computed_var_dependencies: Dict[str, Set[str]] = {} - # Whether to track accessed vars. - track_vars: bool = False - - # The current set of accessed vars during tracking. - tracked_vars: Set[str] = set() - def __init__(self, *args, parent_state: Optional[State] = None, **kwargs): """Initialize the state. @@ -102,22 +96,15 @@ class State(Base, ABC, extra=pydantic.Extra.allow): fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore setattr(self, name, fn) - # Initialize the mutable fields. - self._init_mutable_fields() - # Initialize computed vars dependencies. self.computed_var_dependencies = defaultdict(set) - for cvar in self.computed_vars: - self.tracked_vars = set() - - # Enable tracking and get the computed var. - self.track_vars = True - self.__getattribute__(cvar) - self.track_vars = False - + for cvar_name, cvar in self.computed_vars.items(): # Add the dependencies. - for var in self.tracked_vars: - self.computed_var_dependencies[var].add(cvar) + for var in cvar.deps(): + self.computed_var_dependencies[var].add(cvar_name) + + # Initialize the mutable fields. + self._init_mutable_fields() def _init_mutable_fields(self): """Initialize mutable fields. @@ -199,7 +186,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): **cls.base_vars, **cls.computed_vars, } - cls.computed_var_dependencies = {} cls.event_handlers = {} # Setup the base vars at the class level. @@ -233,8 +219,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): "dirty_substates", "router_data", "computed_var_dependencies", - "track_vars", - "tracked_vars", } @classmethod @@ -508,8 +492,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): If the var is inherited, get the var from the parent state. - If the Var is a dependent of a ComputedVar, track this status in computed_var_dependencies. - Args: name: The name of the var. @@ -520,17 +502,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if not super().__getattribute__("__dict__"): return super().__getattribute__(name) - # Check if tracking is enabled. - if super().__getattribute__("track_vars"): - # Get the non-computed vars. - all_vars = { - **super().__getattribute__("vars"), - **super().__getattribute__("backend_vars"), - } - # Add the var to the tracked vars. - if name in all_vars: - super().__getattribute__("tracked_vars").add(name) - inherited_vars = { **super().__getattribute__("inherited_vars"), **super().__getattribute__("inherited_backend_vars"), @@ -676,55 +647,58 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Return the state update. return StateUpdate(delta=delta, events=events) - def _dirty_computed_vars( - self, from_vars: Optional[Set[str]] = None, check: bool = False - ) -> Set[str]: - """Get ComputedVars that need to be recomputed based on dirty_vars. + def _mark_dirty_computed_vars(self) -> None: + """Mark ComputedVars that need to be recalculated based on dirty_vars.""" + dirty_vars = self.dirty_vars + while dirty_vars: + calc_vars, dirty_vars = dirty_vars, set() + for cvar in self._dirty_computed_vars(from_vars=calc_vars): + self.dirty_vars.add(cvar) + dirty_vars.add(cvar) + actual_var = self.computed_vars.get(cvar) + if actual_var: + actual_var.mark_dirty(instance=self) + + def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]: + """Determine ComputedVars that need to be recalculated based on the given vars. Args: from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars. - check: Whether to perform the check. Returns: Set of computed vars to include in the delta. """ - # If checking is disabled, return all computed vars. - if not check: - return set(self.computed_vars) - - # Return only the computed vars that depend on the dirty vars. return set( cvar for dirty_var in from_vars or self.dirty_vars - for cvar in self.computed_vars - if cvar in self.computed_var_dependencies.get(dirty_var, set()) + for cvar in self.computed_var_dependencies[dirty_var] ) - def get_delta(self, check: bool = False) -> Delta: + def get_delta(self) -> Delta: """Get the delta for the state. - Args: - check: Whether to check for dirty computed vars. - Returns: The delta for the state. """ delta = {} - # Return the dirty vars, as well as computed vars depending on dirty vars. - subdelta = { - prop: getattr(self, prop) - for prop in self.dirty_vars | self._dirty_computed_vars(check=check) - if not types.is_backend_variable(prop) - } - if len(subdelta) > 0: - delta[self.get_full_name()] = subdelta - # Recursively find the substate deltas. substates = self.substates for substate in self.dirty_substates: delta.update(substates[substate].get_delta()) + # Return the dirty vars and dependent computed vars + delta_vars = self.dirty_vars.intersection(self.base_vars).union( + self._dirty_computed_vars() + ) + subdelta = { + prop: getattr(self, prop) + for prop in delta_vars + if not types.is_backend_variable(prop) + } + if len(subdelta) > 0: + delta[self.get_full_name()] = subdelta + # Format the delta. delta = format.format_state(delta) @@ -737,6 +711,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state.mark_dirty() + # have to mark computed vars dirty to allow access to newly computed + # values within the same ComputedVar function + self._mark_dirty_computed_vars() + def clean(self): """Reset the dirty vars.""" # Recursively clean the substates. diff --git a/pynecone/var.py b/pynecone/var.py index ceafc15fc..74ae1c80c 100644 --- a/pynecone/var.py +++ b/pynecone/var.py @@ -1,10 +1,13 @@ """Define a state var.""" from __future__ import annotations +import contextlib +import dis import json import random import string from abc import ABC +from types import FunctionType from typing import ( TYPE_CHECKING, Any, @@ -12,9 +15,11 @@ from typing import ( Dict, List, Optional, + Set, Type, Union, _GenericAlias, # type: ignore + cast, get_type_hints, ) @@ -801,6 +806,84 @@ class ComputedVar(property, Var): assert self.fget is not None, "Var must have a getter." return self.fget.__name__ + @property + def cache_attr(self) -> str: + """Get the attribute used to cache the value on the instance. + + Returns: + An attribute name. + """ + return f"__cached_{self.name}" + + def __get__(self, instance, owner): + """Get the ComputedVar value. + + If the value is already cached on the instance, return the cached value. + + If this ComputedVar doesn't know what type of object it is attached to, then save + a reference as self.__objclass__. + + Args: + instance: the instance of the class accessing this computed var. + owner: the class that this descriptor is attached to. + + Returns: + The value of the var for the given instance. + """ + if not hasattr(self, "__objclass__"): + self.__objclass__ = owner + + if instance is None: + return super().__get__(instance, owner) + + # handle caching + if not hasattr(instance, self.cache_attr): + setattr(instance, self.cache_attr, super().__get__(instance, owner)) + return getattr(instance, self.cache_attr) + + def deps(self, obj: Optional[FunctionType] = None) -> Set[str]: + """Determine var dependencies of this ComputedVar. + + Save references to attributes accessed on "self". Recursively called + when the function makes a method call on "self". + + Args: + obj: the object to disassemble (defaults to the fget function). + + Returns: + A set of variable names accessed by the given obj. + """ + d = set() + if obj is None: + if self.fget is not None: + obj = cast(FunctionType, self.fget) + else: + return set() + if not obj.__code__.co_varnames: + # cannot reference self if method takes no args + return set() + self_name = obj.__code__.co_varnames[0] + self_is_top_of_stack = False + for instruction in dis.get_instructions(obj): + if instruction.opname == "LOAD_FAST" and instruction.argval == self_name: + self_is_top_of_stack = True + continue + if self_is_top_of_stack and instruction.opname == "LOAD_ATTR": + d.add(instruction.argval) + elif self_is_top_of_stack and instruction.opname == "LOAD_METHOD": + d.update(self.deps(obj=getattr(self.__objclass__, instruction.argval))) + self_is_top_of_stack = False + return d + + def mark_dirty(self, instance) -> None: + """Mark this ComputedVar as dirty. + + Args: + instance: the state instance that needs to recompute the value. + """ + with contextlib.suppress(AttributeError): + delattr(instance, self.cache_attr) + @property def type_(self): """Get the type of the var. diff --git a/tests/test_state.py b/tests/test_state.py index 87508980c..260c650b5 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -485,11 +485,11 @@ def test_set_dirty_var(test_state): # Setting a var should mark it as dirty. test_state.num1 = 1 - assert test_state.dirty_vars == {"num1"} + assert test_state.dirty_vars == {"num1", "sum"} # Setting another var should mark it as dirty. test_state.num2 = 2 - assert test_state.dirty_vars == {"num1", "num2"} + assert test_state.dirty_vars == {"num1", "num2", "sum"} # Cleaning the state should remove all dirty vars. test_state.clean() @@ -578,7 +578,7 @@ async def test_process_event_simple(test_state): assert test_state.num1 == 69 # The delta should contain the changes, including computed vars. - assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}} + assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}} assert update.events == [] @@ -601,7 +601,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert child_state.value == "HI" assert child_state.count == 24 assert update.delta == { - "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state": {"value": "HI", "count": 24}, } test_state.clean() @@ -616,7 +615,6 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) update = await test_state._process(event) assert grandchild_state.value2 == "new" assert update.delta == { - "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state.grandchild_state": {"value2": "new"}, } @@ -791,7 +789,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state): interdependent_state: A state with varying Var dependencies. """ interdependent_state.x = 5 - assert interdependent_state.get_delta(check=True) == { + assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"x": 5}, } @@ -806,7 +804,7 @@ def test_dirty_computed_var_from_var(interdependent_state): interdependent_state: A state with varying Var dependencies. """ interdependent_state.v1 = 1 - assert interdependent_state.get_delta(check=True) == { + assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, } @@ -818,7 +816,7 @@ def test_dirty_computed_var_from_backend_var(interdependent_state): interdependent_state: A state with varying Var dependencies. """ interdependent_state._v2 = 2 - assert interdependent_state.get_delta(check=True) == { + assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v2x2": 4}, } @@ -860,6 +858,7 @@ def test_conditional_computed_vars(): assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"} + assert ms.computed_vars["rendered_var"].deps() == {"flag", "t1", "t2"} def test_event_handlers_convert_to_fns(test_state, child_state): @@ -896,3 +895,29 @@ def test_event_handlers_call_other_handlers(): ms = MainState() ms.set_v2(1) assert ms.v == 1 + + +def test_computed_var_cached(): + """Test that a ComputedVar doesn't recalculate when accessed.""" + comp_v_calls = 0 + + class ComputedState(State): + v: int = 0 + + @ComputedVar + def comp_v(self) -> int: + nonlocal comp_v_calls + comp_v_calls += 1 + return self.v + + cs = ComputedState() + assert cs.dict()["v"] == 0 + assert comp_v_calls == 1 + assert cs.dict()["comp_v"] == 0 + assert comp_v_calls == 1 + assert cs.comp_v == 0 + assert comp_v_calls == 1 + cs.v = 1 + assert comp_v_calls == 1 + assert cs.comp_v == 1 + assert comp_v_calls == 2