diff --git a/pynecone/.templates/web/pynecone.json b/pynecone/.templates/web/pynecone.json index 7a17a2ddf..89f1f2322 100644 --- a/pynecone/.templates/web/pynecone.json +++ b/pynecone/.templates/web/pynecone.json @@ -1,3 +1,3 @@ { - "version": "0.1.21" + "version": "0.1.25" } diff --git a/pynecone/state.py b/pynecone/state.py index 5c8885599..7e88bcba3 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -3,9 +3,9 @@ from __future__ import annotations import asyncio import functools -import inspect import traceback from abc import ABC +from collections import defaultdict from typing import ( Any, Callable, @@ -15,7 +15,6 @@ from typing import ( Optional, Sequence, Set, - Tuple, Type, Union, ) @@ -54,9 +53,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Backend vars inherited inherited_backend_vars: ClassVar[Dict[str, Any]] = {} - # Mapping of var name to set of computed variables that depend on it - computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} - # The event handlers. event_handlers: ClassVar[Dict[str, EventHandler]] = {} @@ -75,18 +71,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # The routing path that triggered the state router_data: Dict[str, Any] = {} - def __init__(self, *args, **kwargs): + # Mapping of var name to set of computed variables that depend on it + computed_var_dependencies: Dict[str, Set[str]] = {} + + # Whether to track accessed vars. + track_vars: bool = False + + # The current set of accessed vars during tracking. + tracked_vars: Set[str] = set() + + def __init__(self, *args, parent_state: Optional[State] = None, **kwargs): """Initialize the state. Args: *args: The args to pass to the Pydantic init method. + parent_state: The parent state. **kwargs: The kwargs to pass to the Pydantic init method. """ + kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) # Setup the substates. for substate in self.get_substates(): - self.substates[substate.get_name()] = substate().set(parent_state=self) + self.substates[substate.get_name()] = substate(parent_state=self) # Convert the event handlers to functions. for name, event_handler in self.event_handlers.items(): @@ -95,6 +102,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Initialize the mutable fields. self._init_mutable_fields() + # Initialize computed vars dependencies. + self.computed_var_dependencies = defaultdict(set) + for cvar in self.computed_vars: + self.tracked_vars = set() + + # Enable tracking and get the computed var. + self.track_vars = True + self.__getattribute__(cvar) + self.track_vars = False + + # Add the dependencies. + for var in self.tracked_vars: + self.computed_var_dependencies[var].add(cvar) + def _init_mutable_fields(self): """Initialize mutable fields. @@ -160,17 +181,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): cls.backend_vars = {**cls.inherited_backend_vars, **cls.new_backend_vars} # Set the base and computed vars. - skip_vars = set(cls.inherited_vars) | { - "parent_state", - "substates", - "dirty_vars", - "dirty_substates", - "router_data", - } cls.base_vars = { f.name: BaseVar(name=f.name, type_=f.outer_type_).set_state(cls) for f in cls.get_fields().values() - if f.name not in skip_vars + if f.name not in cls.get_skip_vars() } cls.computed_vars = { v.name: v.set_state(cls) @@ -202,6 +216,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow): cls.event_handlers[name] = handler setattr(cls, name, handler) + @classmethod + def get_skip_vars(cls) -> Set[str]: + """Get the vars to skip when serializing. + + Returns: + The vars to skip when serializing. + """ + return set(cls.inherited_vars) | { + "parent_state", + "substates", + "dirty_vars", + "dirty_substates", + "router_data", + "computed_var_dependencies", + "track_vars", + "tracked_vars", + } + @classmethod @functools.lru_cache() def get_parent_state(cls) -> Optional[Type[State]]: @@ -481,20 +513,21 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The value of the var. """ - vars = { - **super().__getattribute__("vars"), - **super().__getattribute__("backend_vars"), - } - if name in vars: - parent_frame, parent_frame_locals = _get_previous_recursive_frame_info() - if parent_frame is not None: - computed_vars = super().__getattribute__("computed_vars") - requesting_attribute_name = parent_frame_locals.get("name") - if requesting_attribute_name in computed_vars: - # Keep track of any ComputedVar that depends on this Var - super().__getattribute__("computed_var_dependencies").setdefault( - name, set() - ).add(requesting_attribute_name) + # If the state hasn't been initialized yet, return the default value. + if not super().__getattribute__("__dict__"): + return super().__getattribute__(name) + + # Check if tracking is enabled. + if super().__getattribute__("track_vars"): + # Get the non-computed vars. + all_vars = { + **super().__getattribute__("vars"), + **super().__getattribute__("backend_vars"), + } + # Add the var to the tracked vars. + if name in all_vars: + super().__getattribute__("tracked_vars").add(name) + inherited_vars = { **super().__getattribute__("inherited_vars"), **super().__getattribute__("inherited_backend_vars"), @@ -649,18 +682,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: Set of computed vars to include in the delta. """ - dirty_computed_vars = set( + return set( cvar for dirty_var in from_vars or self.dirty_vars for cvar in self.computed_vars if cvar in self.computed_var_dependencies.get(dirty_var, set()) ) - if dirty_computed_vars: - # recursive call to catch computed vars that depend on computed vars - return dirty_computed_vars | self._dirty_computed_vars( - from_vars=dirty_computed_vars - ) - return dirty_computed_vars def get_delta(self) -> Delta: """Get the delta for the state. @@ -844,24 +871,3 @@ def _convert_mutable_datatypes( field_value, reassign_field=reassign_field, field_name=field_name ) return field_value - - -def _get_previous_recursive_frame_info() -> ( - Tuple[Optional[inspect.FrameInfo], Dict[str, Any]] -): - """Find the previous frame of the same function that calls this helper. - - For example, if this function is called from `State.__getattribute__` - (parent frame), then the returned frame will be the next earliest call - of the same function. - - Returns: - Tuple of (frame_info, local_vars) - - If no previous recursive frame is found up the stack, the frame info will be None. - """ - _this_frame, parent_frame, *prev_frames = inspect.stack() - for frame in prev_frames: - if frame.frame.f_code == parent_frame.frame.f_code: - return frame, frame.frame.f_locals - return None, {} diff --git a/tests/test_state.py b/tests/test_state.py index ba65e2f08..fad15b2d2 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -155,13 +155,7 @@ def test_base_class_vars(test_state): cls = type(test_state) for field in fields: - if field in ( - "parent_state", - "substates", - "dirty_vars", - "dirty_substates", - "router_data", - ): + if field in test_state.get_skip_vars(): continue prop = getattr(cls, field) assert isinstance(prop, BaseVar) @@ -819,3 +813,19 @@ def test_dirty_computed_var_from_backend_var(interdependent_state): assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v2x2": 4}, } + + +def test_child_state(): + class MainState(State): + v: int = 2 + + class ChildState(MainState): + @ComputedVar + def rendered_var(self): + return self.v + + ms = MainState() + cs = ms.substates[ChildState.get_name()] + assert ms.v == 2 + assert cs.v == 2 + assert cs.rendered_var == 2