diff --git a/pynecone/state.py b/pynecone/state.py index e5c452e05..0417b7380 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -5,7 +5,18 @@ import asyncio import functools import traceback from abc import ABC -from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Set, Type +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Sequence, + Set, + Type, + Union, +) import cloudpickle from redis import Redis @@ -13,7 +24,7 @@ from redis import Redis from pynecone import constants, utils from pynecone.base import Base from pynecone.event import Event, EventHandler, window_alert -from pynecone.var import BaseVar, ComputedVar, PCList, Var +from pynecone.var import BaseVar, ComputedVar, PCDict, PCList, Var Delta = Dict[str, Any] @@ -79,7 +90,7 @@ class State(Base, ABC): value, self._reassign_field, field.name ) - if utils._issubclass(field.type_, List): + if utils._issubclass(field.type_, Union[List, Dict]): setattr(self, field.name, value_in_pc_data) self.clean() @@ -727,5 +738,7 @@ def _convert_mutable_datatypes( field_value[key] = _convert_mutable_datatypes( value, reassign_field, field_name ) - + field_value = PCDict( + field_value, reassign_field=reassign_field, field_name=field_name + ) return field_value diff --git a/pynecone/var.py b/pynecone/var.py index 27ba9e991..6a0d31283 100644 --- a/pynecone/var.py +++ b/pynecone/var.py @@ -773,72 +773,155 @@ class PCList(list): super().__init__(original_list) - def append(self, *args, **kargs): + def append(self, *args, **kwargs): """Append. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().append(*args, **kargs) + super().append(*args, **kwargs) self._reassign_field() - def __setitem__(self, *args, **kargs): + def __setitem__(self, *args, **kwargs): """Set item. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().__setitem__(*args, **kargs) + super().__setitem__(*args, **kwargs) self._reassign_field() - def __delitem__(self, *args, **kargs): + def __delitem__(self, *args, **kwargs): """Delete item. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().__delitem__(*args, **kargs) + super().__delitem__(*args, **kwargs) self._reassign_field() - def clear(self, *args, **kargs): + def clear(self, *args, **kwargs): """Remove all item from the list. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().clear(*args, **kargs) + super().clear(*args, **kwargs) self._reassign_field() - def extend(self, *args, **kargs): + def extend(self, *args, **kwargs): """Add all item of a list to the end of the list. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().extend(*args, **kargs) + super().extend(*args, **kwargs) self._reassign_field() if hasattr(self, "_reassign_field") else None - def pop(self, *args, **kargs): + def pop(self, *args, **kwargs): """Remove an element. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().pop(*args, **kargs) + super().pop(*args, **kwargs) self._reassign_field() - def remove(self, *args, **kargs): + def remove(self, *args, **kwargs): """Remove an element. Args: args: The args passed. - kargs: The kwargs passed. + kwargs: The kwargs passed. """ - super().remove(*args, **kargs) + super().remove(*args, **kwargs) + self._reassign_field() + + +class PCDict(dict): + """A custom dict that pynecone can detect its mutation.""" + + def __init__( + self, + original_dict: Dict, + reassign_field: Callable = lambda _field_name: None, + field_name: str = "", + ): + """Initialize PCDict. + + Args: + original_dict: The original dict + reassign_field: + The method in the parent state to reassign the field. + Default to be a no-op function + field_name: the name of field in the parent state + """ + super().__init__(original_dict) + self._reassign_field = lambda: reassign_field(field_name) + + def clear(self): + """Remove all item from the list.""" + super().clear() + + self._reassign_field() + + def setdefault(self, *args, **kwargs): + """set default. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().setdefault(*args, **kwargs) + self._reassign_field() + + def popitem(self): + """Pop last item.""" + super().popitem() + self._reassign_field() + + def pop(self, k, d=None): + """Remove an element. + + Args: + k: The args passed. + d: The kwargs passed. + """ + super().pop(k, d) + self._reassign_field() + + def update(self, *args, **kwargs): + """update dict. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().update(*args, **kwargs) + self._reassign_field() + + def __setitem__(self, *args, **kwargs): + """set item. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().__setitem__(*args, **kwargs) + self._reassign_field() if hasattr(self, "_reassign_field") else None + + def __delitem__(self, *args, **kwargs): + """delete item. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().__delitem__(*args, **kwargs) self._reassign_field() diff --git a/tests/conftest.py b/tests/conftest.py index 8e15c0a74..08a3585b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ from typing import Generator import pytest +from pynecone.state import State + @pytest.fixture(scope="function") def windows_platform() -> Generator: @@ -13,3 +15,122 @@ def windows_platform() -> Generator: whether system is windows. """ yield platform.system() == "Windows" + + +@pytest.fixture +def list_mutation_state(): + """A fixture to create a state with list mutation features. + + Returns: + A state with list mutation features. + """ + + class TestState(State): + """The test state.""" + + # plain list + plain_friends = ["Tommy"] + + def make_friend(self): + self.plain_friends.append("another-fd") + + def change_first_friend(self): + self.plain_friends[0] = "Jenny" + + def unfriend_all_friends(self): + self.plain_friends.clear() + + def unfriend_first_friend(self): + del self.plain_friends[0] + + def remove_last_friend(self): + self.plain_friends.pop() + + def make_friends_with_colleagues(self): + colleagues = ["Peter", "Jimmy"] + self.plain_friends.extend(colleagues) + + def remove_tommy(self): + self.plain_friends.remove("Tommy") + + # list in dict + friends_in_dict = {"Tommy": ["Jenny"]} + + def remove_jenny_from_tommy(self): + self.friends_in_dict["Tommy"].remove("Jenny") + + def add_jimmy_to_tommy_friends(self): + self.friends_in_dict["Tommy"].append("Jimmy") + + def tommy_has_no_fds(self): + self.friends_in_dict["Tommy"].clear() + + # nested list + friends_in_nested_list = [["Tommy"], ["Jenny"]] + + def remove_first_group(self): + self.friends_in_nested_list.pop(0) + + def remove_first_person_from_first_group(self): + self.friends_in_nested_list[0].pop(0) + + def add_jimmy_to_second_group(self): + self.friends_in_nested_list[1].append("Jimmy") + + return TestState() + + +@pytest.fixture +def dict_mutation_state(): + """A fixture to create a state with dict mutation features. + + Returns: + A state with dict mutation features. + """ + + class TestState(State): + """The test state.""" + + # plain dict + details = {"name": "Tommy"} + + def add_age(self): + self.details.update({"age": 20}) # type: ignore + + def change_name(self): + self.details["name"] = "Jenny" + + def remove_last_detail(self): + self.details.popitem() + + def clear_details(self): + self.details.clear() + + def remove_name(self): + del self.details["name"] + + def pop_out_age(self): + self.details.pop("age") + + # dict in list + address = [{"home": "home address"}, {"work": "work address"}] + + def remove_home_address(self): + self.address[0].pop("home") + + def add_street_to_home_address(self): + self.address[0]["street"] = "street address" + + # nested dict + friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}} + + def change_friend_name(self): + self.friend_in_nested_dict["friend"]["name"] = "Tommy" + + def remove_friend(self): + self.friend_in_nested_dict.pop("friend") + + def add_friend_age(self): + self.friend_in_nested_dict["friend"]["age"] = 30 + + return TestState() diff --git a/tests/test_app.py b/tests/test_app.py index 169cead84..eb2a71dc9 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -164,69 +164,6 @@ def test_set_and_get_state(TestState: Type[State]): assert state2.var == 2 # type: ignore -@pytest.fixture -def list_mutation_state(): - """A fixture to create a state with list mutation features. - - Returns: - A state with list mutation features. - """ - - class TestState(State): - """The test state.""" - - # plain list - plain_friends = ["Tommy"] - - def make_friend(self): - self.plain_friends.append("another-fd") - - def change_first_friend(self): - self.plain_friends[0] = "Jenny" - - def unfriend_all_friends(self): - self.plain_friends.clear() - - def unfriend_first_friend(self): - del self.plain_friends[0] - - def remove_last_friend(self): - self.plain_friends.pop() - - def make_friends_with_colleagues(self): - colleagues = ["Peter", "Jimmy"] - self.plain_friends.extend(colleagues) - - def remove_tommy(self): - self.plain_friends.remove("Tommy") - - # list in dict - friends_in_dict = {"Tommy": ["Jenny"]} - - def remove_jenny_from_tommy(self): - self.friends_in_dict["Tommy"].remove("Jenny") - - def add_jimmy_to_tommy_friends(self): - self.friends_in_dict["Tommy"].append("Jimmy") - - def tommy_has_no_fds(self): - self.friends_in_dict["Tommy"].clear() - - # nested list - friends_in_nested_list = [["Tommy"], ["Jenny"]] - - def remove_first_group(self): - self.friends_in_nested_list.pop(0) - - def remove_first_person_from_first_group(self): - self.friends_in_nested_list[0].pop(0) - - def add_jimmy_to_second_group(self): - self.friends_in_nested_list[1].append("Jimmy") - - return TestState() - - @pytest.mark.asyncio @pytest.mark.parametrize( "event_tuples", @@ -343,3 +280,130 @@ async def test_list_mutation_detection__plain_list( ) assert result.delta == expected_delta + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "event_tuples", + [ + pytest.param( + [ + ( + "test_state.add_age", + {"test_state": {"details": {"name": "Tommy", "age": 20}}}, + ), + ( + "test_state.change_name", + {"test_state": {"details": {"name": "Jenny", "age": 20}}}, + ), + ( + "test_state.remove_last_detail", + {"test_state": {"details": {"name": "Jenny"}}}, + ), + ], + id="update then __setitem__", + ), + pytest.param( + [ + ( + "test_state.clear_details", + {"test_state": {"details": {}}}, + ), + ( + "test_state.add_age", + {"test_state": {"details": {"age": 20}}}, + ), + ], + id="delitem then update", + ), + pytest.param( + [ + ( + "test_state.add_age", + {"test_state": {"details": {"name": "Tommy", "age": 20}}}, + ), + ( + "test_state.remove_name", + {"test_state": {"details": {"age": 20}}}, + ), + ( + "test_state.pop_out_age", + {"test_state": {"details": {}}}, + ), + ], + id="add, remove, pop", + ), + pytest.param( + [ + ( + "test_state.remove_home_address", + {"test_state": {"address": [{}, {"work": "work address"}]}}, + ), + ( + "test_state.add_street_to_home_address", + { + "test_state": { + "address": [ + {"street": "street address"}, + {"work": "work address"}, + ] + } + }, + ), + ], + id="dict in list", + ), + pytest.param( + [ + ( + "test_state.change_friend_name", + { + "test_state": { + "friend_in_nested_dict": { + "name": "Nikhil", + "friend": {"name": "Tommy"}, + } + } + }, + ), + ( + "test_state.add_friend_age", + { + "test_state": { + "friend_in_nested_dict": { + "name": "Nikhil", + "friend": {"name": "Tommy", "age": 30}, + } + } + }, + ), + ( + "test_state.remove_friend", + {"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}}, + ), + ], + id="nested dict", + ), + ], +) +async def test_dict_mutation_detection__plain_list( + event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State +): + """Test dict mutation detection + when reassignment is not explicitly included in the logic. + + Args: + event_tuples: From parametrization. + dict_mutation_state: A state with dict mutation features. + """ + for event_name, expected_delta in event_tuples: + result = await dict_mutation_state.process( + Event( + token="fake-token", + name=event_name, + router_data={"pathname": "/", "query": {}}, + payload={}, + ) + ) + + assert result.delta == expected_delta diff --git a/tests/test_var.py b/tests/test_var.py index 925907c48..80bea3bde 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -4,7 +4,7 @@ import cloudpickle import pytest from pynecone.base import Base -from pynecone.var import BaseVar, PCList, Var +from pynecone.var import BaseVar, PCDict, PCList, Var test_vars = [ BaseVar(name="prop1", type_=int), @@ -218,3 +218,13 @@ def test_pickleable_pc_list(): pickled_list = cloudpickle.dumps(pc_list) assert cloudpickle.loads(pickled_list) == pc_list + + +def test_pickleable_pc_dict(): + """Test that PCDict is pickleable.""" + pc_dict = PCDict( + original_dict={1: 2, 3: 4}, reassign_field=lambda x: x, field_name="random" + ) + + pickled_dict = cloudpickle.dumps(pc_dict) + assert cloudpickle.loads(pickled_dict) == pc_dict