diff --git a/pynecone/state.py b/pynecone/state.py index 7e88bcba3..cded54416 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -673,15 +673,23 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Return the state update. return StateUpdate(delta=delta, events=events) - def _dirty_computed_vars(self, from_vars: Optional[Set[str]] = None) -> Set[str]: + def _dirty_computed_vars( + self, from_vars: Optional[Set[str]] = None, check: bool = False + ) -> Set[str]: """Get ComputedVars that need to be recomputed based on dirty_vars. Args: from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars. + check: Whether to perform the check. Returns: Set of computed vars to include in the delta. """ + # If checking is disabled, return all computed vars. + if not check: + return set(self.computed_vars) + + # Return only the computed vars that depend on the dirty vars. return set( cvar for dirty_var in from_vars or self.dirty_vars @@ -689,9 +697,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if cvar in self.computed_var_dependencies.get(dirty_var, set()) ) - def get_delta(self) -> Delta: + def get_delta(self, check: bool = False) -> Delta: """Get the delta for the state. + Args: + check: Whether to check for dirty computed vars. + Returns: The delta for the state. """ @@ -700,7 +711,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Return the dirty vars, as well as computed vars depending on dirty vars. subdelta = { prop: getattr(self, prop) - for prop in self.dirty_vars | self._dirty_computed_vars() + for prop in self.dirty_vars | self._dirty_computed_vars(check=check) if not types.is_backend_variable(prop) } if len(subdelta) > 0: diff --git a/tests/test_state.py b/tests/test_state.py index fad15b2d2..4b1e43802 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -578,7 +578,7 @@ async def test_process_event_simple(test_state): assert test_state.num1 == 69 # The delta should contain the changes, including computed vars. - assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}} + assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}} assert update.events == [] @@ -601,6 +601,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert child_state.value == "HI" assert child_state.count == 24 assert update.delta == { + "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state": {"value": "HI", "count": 24}, } test_state.clean() @@ -615,6 +616,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) update = await test_state._process(event) assert grandchild_state.value2 == "new" assert update.delta == { + "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state.grandchild_state": {"value2": "new"}, } @@ -783,7 +785,7 @@ def test_not_dirty_computed_var_from_var(interdependent_state): interdependent_state: A state with varying Var dependencies. """ interdependent_state.x = 5 - assert interdependent_state.get_delta() == { + assert interdependent_state.get_delta(check=True) == { interdependent_state.get_full_name(): {"x": 5}, } @@ -798,7 +800,7 @@ def test_dirty_computed_var_from_var(interdependent_state): interdependent_state: A state with varying Var dependencies. """ interdependent_state.v1 = 1 - assert interdependent_state.get_delta() == { + assert interdependent_state.get_delta(check=True) == { interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, } @@ -810,12 +812,14 @@ def test_dirty_computed_var_from_backend_var(interdependent_state): interdependent_state: A state with varying Var dependencies. """ interdependent_state._v2 = 2 - assert interdependent_state.get_delta() == { + assert interdependent_state.get_delta(check=True) == { interdependent_state.get_full_name(): {"v2x2": 4}, } def test_child_state(): + """Test that the child state computed vars can reference parent state vars.""" + class MainState(State): v: int = 2 @@ -829,3 +833,24 @@ def test_child_state(): assert ms.v == 2 assert cs.v == 2 assert cs.rendered_var == 2 + + +def test_conditional_computed_vars(): + """Test that computed vars can have conditionals.""" + + class MainState(State): + flag: bool = False + t1: str = "a" + t2: str = "b" + + @ComputedVar + def rendered_var(self) -> str: + if self.flag: + return self.t1 + return self.t2 + + ms = MainState() + # Initially there are no dirty computed vars. + assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} + assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} + assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}