From d5a76f103ab98daeeb0504ee4477a4811a5d9e5b Mon Sep 17 00:00:00 2001 From: Tommy Dew Date: Sun, 29 Jan 2023 02:50:52 +0800 Subject: [PATCH] Add `pc.list` for mutation detection (#339) --- pynecone/state.py | 71 +++++++++++++++++- pynecone/var.py | 93 +++++++++++++++++++++++ tests/test_app.py | 184 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 346 insertions(+), 2 deletions(-) diff --git a/pynecone/state.py b/pynecone/state.py index 095700128..a5901e9f1 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -13,7 +13,7 @@ from redis import Redis from pynecone import constants, utils from pynecone.base import Base from pynecone.event import Event, EventHandler, window_alert -from pynecone.var import BaseVar, ComputedVar, Var +from pynecone.var import BaseVar, ComputedVar, PCList, Var Delta = Dict[str, Any] @@ -61,6 +61,40 @@ class State(Base, ABC): for substate in self.get_substates(): self.substates[substate.get_name()] = substate().set(parent_state=self) + self._init_mutable_fields() + + def _init_mutable_fields(self): + """Initialize mutable fields. + + So that mutation to them can be detected by the app: + * list + """ + for field in self.base_vars.values(): + value = getattr(self, field.name) + + value_in_pc_data = _convert_mutable_datatypes( + value, self._reassign_field, field.name + ) + + if utils._issubclass(field.type_, List): + setattr(self, field.name, value_in_pc_data) + + self.clean() + + def _reassign_field(self, field_name: str): + """Reassign the given field. + + Primarily for mutation in fields of mutable data types. + + Args: + field_name (str): The name of the field we want to reassign + """ + setattr( + self, + field_name, + getattr(self, field_name), + ) + def __repr__(self) -> str: """Get the string representation of the state. @@ -578,3 +612,38 @@ class StateManager(Base): if self.redis is None: return self.redis.set(token, pickle.dumps(state), ex=self.token_expiration) + + +def _convert_mutable_datatypes( + field_value: Any, reassign_field: Callable, field_name: str +) -> Any: + """Recursively convert mutable data to the Pc data types. + + Note: right now only list & dict would be handled recursively. + + Args: + field_value: The target field_value. + reassign_field: + The function to reassign the field in the parent state. + field_name: the name of the field in the parent state + + Returns: + The converted field_value + """ + if isinstance(field_value, list): + for index in range(len(field_value)): + field_value[index] = _convert_mutable_datatypes( + field_value[index], reassign_field, field_name + ) + + field_value = PCList( + field_value, reassign_field=reassign_field, field_name=field_name + ) + + elif isinstance(field_value, dict): + for key, value in field_value.items(): + field_value[key] = _convert_mutable_datatypes( + value, reassign_field, field_name + ) + + return field_value diff --git a/pynecone/var.py b/pynecone/var.py index 21816bd66..dc1a2ffa2 100644 --- a/pynecone/var.py +++ b/pynecone/var.py @@ -748,3 +748,96 @@ class ComputedVar(property, Var): if "return" in self.fget.__annotations__: return self.fget.__annotations__["return"] return Any + + +class PCList(list): + """A custom list that pynecone can detect its mutation.""" + + def __init__( + self, + original_list: List, + reassign_field: Callable = lambda _field_name: None, + field_name: str = "", + ): + """Initialize PCList. + + Args: + original_list (List): The original list + reassign_field (Callable): + The method in the parent state to reassign the field. + Default to be a no-op function + field_name (str): the name of field in the parent state + """ + self._reassign_field = lambda: reassign_field(field_name) + + super().__init__(original_list) + + def append(self, *args, **kargs): + """Append. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().append(*args, **kargs) + self._reassign_field() + + def __setitem__(self, *args, **kargs): + """Set item. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().__setitem__(*args, **kargs) + self._reassign_field() + + def __delitem__(self, *args, **kargs): + """Delete item. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().__delitem__(*args, **kargs) + self._reassign_field() + + def clear(self, *args, **kargs): + """Remove all item from the list. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().clear(*args, **kargs) + self._reassign_field() + + def extend(self, *args, **kargs): + """Add all item of a list to the end of the list. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().extend(*args, **kargs) + self._reassign_field() + + def pop(self, *args, **kargs): + """Remove an element. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().pop(*args, **kargs) + self._reassign_field() + + def remove(self, *args, **kargs): + """Remove an element. + + Args: + args: The args passed. + kargs: The kwargs passed. + """ + super().remove(*args, **kargs) + self._reassign_field() diff --git a/tests/test_app.py b/tests/test_app.py index 4f1dc09af..e5adc841f 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,9 +1,10 @@ -from typing import Type +from typing import List, Tuple, Type import pytest from pynecone.app import App, DefaultState from pynecone.components import Box +from pynecone.event import Event from pynecone.middleware import HydrateMiddleware from pynecone.state import State from pynecone.style import Style @@ -156,3 +157,184 @@ def test_set_and_get_state(TestState: Type[State]): state2 = app.state_manager.get_state(token2) assert state1.var == 1 # type: ignore assert state2.var == 2 # type: ignore + + +@pytest.fixture +def list_mutation_state(): + """A fixture to create a state with list mutation features. + + Returns: + A state with list mutation features. + """ + + class TestState(State): + """The test state.""" + + # plain list + plain_friends = ["Tommy"] + + def make_friend(self): + self.plain_friends.append("another-fd") + + def change_first_friend(self): + self.plain_friends[0] = "Jenny" + + def unfriend_all_friends(self): + self.plain_friends.clear() + + def unfriend_first_friend(self): + del self.plain_friends[0] + + def remove_last_friend(self): + self.plain_friends.pop() + + def make_friends_with_colleagues(self): + colleagues = ["Peter", "Jimmy"] + self.plain_friends.extend(colleagues) + + def remove_tommy(self): + self.plain_friends.remove("Tommy") + + # list in dict + friends_in_dict = {"Tommy": ["Jenny"]} + + def remove_jenny_from_tommy(self): + self.friends_in_dict["Tommy"].remove("Jenny") + + def add_jimmy_to_tommy_friends(self): + self.friends_in_dict["Tommy"].append("Jimmy") + + def tommy_has_no_fds(self): + self.friends_in_dict["Tommy"].clear() + + # nested list + friends_in_nested_list = [["Tommy"], ["Jenny"]] + + def remove_first_group(self): + self.friends_in_nested_list.pop(0) + + def remove_first_person_from_first_group(self): + self.friends_in_nested_list[0].pop(0) + + def add_jimmy_to_second_group(self): + self.friends_in_nested_list[1].append("Jimmy") + + return TestState() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "event_tuples", + [ + pytest.param( + [ + ( + "test_state.make_friend", + {"test_state": {"plain_friends": ["Tommy", "another-fd"]}}, + ), + ( + "test_state.change_first_friend", + {"test_state": {"plain_friends": ["Jenny", "another-fd"]}}, + ), + ], + id="append then __setitem__", + ), + pytest.param( + [ + ( + "test_state.unfriend_first_friend", + {"test_state": {"plain_friends": []}}, + ), + ( + "test_state.make_friend", + {"test_state": {"plain_friends": ["another-fd"]}}, + ), + ], + id="delitem then append", + ), + pytest.param( + [ + ( + "test_state.make_friends_with_colleagues", + {"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}}, + ), + ( + "test_state.remove_tommy", + {"test_state": {"plain_friends": ["Peter", "Jimmy"]}}, + ), + ( + "test_state.remove_last_friend", + {"test_state": {"plain_friends": ["Peter"]}}, + ), + ( + "test_state.unfriend_all_friends", + {"test_state": {"plain_friends": []}}, + ), + ], + id="extend, remove, pop, clear", + ), + pytest.param( + [ + ( + "test_state.add_jimmy_to_second_group", + { + "test_state": { + "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]] + } + }, + ), + ( + "test_state.remove_first_person_from_first_group", + { + "test_state": { + "friends_in_nested_list": [[], ["Jenny", "Jimmy"]] + } + }, + ), + ( + "test_state.remove_first_group", + {"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}}, + ), + ], + id="nested list", + ), + pytest.param( + [ + ( + "test_state.add_jimmy_to_tommy_friends", + {"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}}, + ), + ( + "test_state.remove_jenny_from_tommy", + {"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}}, + ), + ( + "test_state.tommy_has_no_fds", + {"test_state": {"friends_in_dict": {"Tommy": []}}}, + ), + ], + id="list in dict", + ), + ], +) +async def test_list_mutation_detection__plain_list( + event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State +): + """Test list mutation detection + when reassignment is not explicitly included in the logic. + + Args: + event_tuples: From parametrization. + list_mutation_state: A state with list mutation features. + """ + for event_name, expected_delta in event_tuples: + result = await list_mutation_state.process( + Event( + token="fake-token", + name=event_name, + router_data={"pathname": "/", "query": {}}, + payload={}, + ) + ) + + assert result.delta == expected_delta