Add pc.list for mutation detection (#339)

This commit is contained in:
Tommy Dew 2023-01-29 02:50:52 +08:00 committed by GitHub
parent 5aae6a122d
commit d5a76f103a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 346 additions and 2 deletions

View File

@ -13,7 +13,7 @@ from redis import Redis
from pynecone import constants, utils
from pynecone.base import Base
from pynecone.event import Event, EventHandler, window_alert
from pynecone.var import BaseVar, ComputedVar, Var
from pynecone.var import BaseVar, ComputedVar, PCList, Var
Delta = Dict[str, Any]
@ -61,6 +61,40 @@ class State(Base, ABC):
for substate in self.get_substates():
self.substates[substate.get_name()] = substate().set(parent_state=self)
self._init_mutable_fields()
def _init_mutable_fields(self):
"""Initialize mutable fields.
So that mutation to them can be detected by the app:
* list
"""
for field in self.base_vars.values():
value = getattr(self, field.name)
value_in_pc_data = _convert_mutable_datatypes(
value, self._reassign_field, field.name
)
if utils._issubclass(field.type_, List):
setattr(self, field.name, value_in_pc_data)
self.clean()
def _reassign_field(self, field_name: str):
"""Reassign the given field.
Primarily for mutation in fields of mutable data types.
Args:
field_name (str): The name of the field we want to reassign
"""
setattr(
self,
field_name,
getattr(self, field_name),
)
def __repr__(self) -> str:
"""Get the string representation of the state.
@ -578,3 +612,38 @@ class StateManager(Base):
if self.redis is None:
return
self.redis.set(token, pickle.dumps(state), ex=self.token_expiration)
def _convert_mutable_datatypes(
field_value: Any, reassign_field: Callable, field_name: str
) -> Any:
"""Recursively convert mutable data to the Pc data types.
Note: right now only list & dict would be handled recursively.
Args:
field_value: The target field_value.
reassign_field:
The function to reassign the field in the parent state.
field_name: the name of the field in the parent state
Returns:
The converted field_value
"""
if isinstance(field_value, list):
for index in range(len(field_value)):
field_value[index] = _convert_mutable_datatypes(
field_value[index], reassign_field, field_name
)
field_value = PCList(
field_value, reassign_field=reassign_field, field_name=field_name
)
elif isinstance(field_value, dict):
for key, value in field_value.items():
field_value[key] = _convert_mutable_datatypes(
value, reassign_field, field_name
)
return field_value

View File

@ -748,3 +748,96 @@ class ComputedVar(property, Var):
if "return" in self.fget.__annotations__:
return self.fget.__annotations__["return"]
return Any
class PCList(list):
"""A custom list that pynecone can detect its mutation."""
def __init__(
self,
original_list: List,
reassign_field: Callable = lambda _field_name: None,
field_name: str = "",
):
"""Initialize PCList.
Args:
original_list (List): The original list
reassign_field (Callable):
The method in the parent state to reassign the field.
Default to be a no-op function
field_name (str): the name of field in the parent state
"""
self._reassign_field = lambda: reassign_field(field_name)
super().__init__(original_list)
def append(self, *args, **kargs):
"""Append.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().append(*args, **kargs)
self._reassign_field()
def __setitem__(self, *args, **kargs):
"""Set item.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().__setitem__(*args, **kargs)
self._reassign_field()
def __delitem__(self, *args, **kargs):
"""Delete item.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().__delitem__(*args, **kargs)
self._reassign_field()
def clear(self, *args, **kargs):
"""Remove all item from the list.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().clear(*args, **kargs)
self._reassign_field()
def extend(self, *args, **kargs):
"""Add all item of a list to the end of the list.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().extend(*args, **kargs)
self._reassign_field()
def pop(self, *args, **kargs):
"""Remove an element.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().pop(*args, **kargs)
self._reassign_field()
def remove(self, *args, **kargs):
"""Remove an element.
Args:
args: The args passed.
kargs: The kwargs passed.
"""
super().remove(*args, **kargs)
self._reassign_field()

View File

@ -1,9 +1,10 @@
from typing import Type
from typing import List, Tuple, Type
import pytest
from pynecone.app import App, DefaultState
from pynecone.components import Box
from pynecone.event import Event
from pynecone.middleware import HydrateMiddleware
from pynecone.state import State
from pynecone.style import Style
@ -156,3 +157,184 @@ def test_set_and_get_state(TestState: Type[State]):
state2 = app.state_manager.get_state(token2)
assert state1.var == 1 # type: ignore
assert state2.var == 2 # type: ignore
@pytest.fixture
def list_mutation_state():
"""A fixture to create a state with list mutation features.
Returns:
A state with list mutation features.
"""
class TestState(State):
"""The test state."""
# plain list
plain_friends = ["Tommy"]
def make_friend(self):
self.plain_friends.append("another-fd")
def change_first_friend(self):
self.plain_friends[0] = "Jenny"
def unfriend_all_friends(self):
self.plain_friends.clear()
def unfriend_first_friend(self):
del self.plain_friends[0]
def remove_last_friend(self):
self.plain_friends.pop()
def make_friends_with_colleagues(self):
colleagues = ["Peter", "Jimmy"]
self.plain_friends.extend(colleagues)
def remove_tommy(self):
self.plain_friends.remove("Tommy")
# list in dict
friends_in_dict = {"Tommy": ["Jenny"]}
def remove_jenny_from_tommy(self):
self.friends_in_dict["Tommy"].remove("Jenny")
def add_jimmy_to_tommy_friends(self):
self.friends_in_dict["Tommy"].append("Jimmy")
def tommy_has_no_fds(self):
self.friends_in_dict["Tommy"].clear()
# nested list
friends_in_nested_list = [["Tommy"], ["Jenny"]]
def remove_first_group(self):
self.friends_in_nested_list.pop(0)
def remove_first_person_from_first_group(self):
self.friends_in_nested_list[0].pop(0)
def add_jimmy_to_second_group(self):
self.friends_in_nested_list[1].append("Jimmy")
return TestState()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"event_tuples",
[
pytest.param(
[
(
"test_state.make_friend",
{"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
),
(
"test_state.change_first_friend",
{"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
),
],
id="append then __setitem__",
),
pytest.param(
[
(
"test_state.unfriend_first_friend",
{"test_state": {"plain_friends": []}},
),
(
"test_state.make_friend",
{"test_state": {"plain_friends": ["another-fd"]}},
),
],
id="delitem then append",
),
pytest.param(
[
(
"test_state.make_friends_with_colleagues",
{"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
),
(
"test_state.remove_tommy",
{"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
),
(
"test_state.remove_last_friend",
{"test_state": {"plain_friends": ["Peter"]}},
),
(
"test_state.unfriend_all_friends",
{"test_state": {"plain_friends": []}},
),
],
id="extend, remove, pop, clear",
),
pytest.param(
[
(
"test_state.add_jimmy_to_second_group",
{
"test_state": {
"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
}
},
),
(
"test_state.remove_first_person_from_first_group",
{
"test_state": {
"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
}
},
),
(
"test_state.remove_first_group",
{"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
),
],
id="nested list",
),
pytest.param(
[
(
"test_state.add_jimmy_to_tommy_friends",
{"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
),
(
"test_state.remove_jenny_from_tommy",
{"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
),
(
"test_state.tommy_has_no_fds",
{"test_state": {"friends_in_dict": {"Tommy": []}}},
),
],
id="list in dict",
),
],
)
async def test_list_mutation_detection__plain_list(
event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
):
"""Test list mutation detection
when reassignment is not explicitly included in the logic.
Args:
event_tuples: From parametrization.
list_mutation_state: A state with list mutation features.
"""
for event_name, expected_delta in event_tuples:
result = await list_mutation_state.process(
Event(
token="fake-token",
name=event_name,
router_data={"pathname": "/", "query": {}},
payload={},
)
)
assert result.delta == expected_delta