Added support for RelexSet wrapper (#1535)

This commit is contained in:
Smit Parmar 2023-08-10 22:29:03 +05:30 committed by GitHub
parent cd47815a4d
commit ef78465f16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 121 additions and 6 deletions

View File

@ -34,7 +34,7 @@ from reflex import constants
from reflex.base import Base
from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
from reflex.utils import format, prerequisites, types
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, Var
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet, Var
Delta = Dict[str, Any]
@ -601,8 +601,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self.mark_dirty()
return
# Make sure lists and dicts are converted to ReflexList and ReflexDict.
if name in self.vars and types._isinstance(value, Union[List, Dict]):
# Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
if name in self.vars and types._isinstance(value, Union[List, Dict, Set]):
value = _convert_mutable_datatypes(value, self._reassign_field, name)
# Set the attribute.
@ -985,7 +985,7 @@ def _convert_mutable_datatypes(
) -> Any:
"""Recursively convert mutable data to the Rx data types.
Note: right now only list & dict would be handled recursively.
Note: right now only list, dict and set would be handled recursively.
Args:
field_value: The target field_value.
@ -1015,4 +1015,14 @@ def _convert_mutable_datatypes(
field_value, reassign_field=reassign_field, field_name=field_name
)
if isinstance(field_value, set):
field_value = [
_convert_mutable_datatypes(value, reassign_field, field_name)
for value in field_value
]
field_value = ReflexSet(
field_value, reassign_field=reassign_field, field_name=field_name
)
return field_value

View File

@ -1130,6 +1130,91 @@ class ReflexDict(dict):
self._reassign_field()
class ReflexSet(set):
"""A custom set that reflex can detect its mutation."""
def __init__(
self,
original_set: Set,
reassign_field: Callable = lambda _field_name: None,
field_name: str = "",
):
"""Initialize ReflexSet.
Args:
original_set (Set): The original set
reassign_field (Callable):
The method in the parent state to reassign the field.
Default to be a no-op function
field_name (str): the name of field in the parent state
"""
self._reassign_field = lambda: reassign_field(field_name)
super().__init__(original_set)
def add(self, *args, **kwargs):
"""Add an element to set.
Args:
args: The args passed.
kwargs: The kwargs passed.
"""
super().add(*args, **kwargs)
self._reassign_field()
def remove(self, *args, **kwargs):
"""Remove an element.
Raise key error if element not found.
Args:
args: The args passed.
kwargs: The kwargs passed.
"""
super().remove(*args, **kwargs)
self._reassign_field()
def discard(self, *args, **kwargs):
"""Remove an element.
Does not raise key error if element not found.
Args:
args: The args passed.
kwargs: The kwargs passed.
"""
super().discard(*args, **kwargs)
self._reassign_field()
def pop(self, *args, **kwargs):
"""Remove an element.
Args:
args: The args passed.
kwargs: The kwargs passed.
"""
super().pop(*args, **kwargs)
self._reassign_field()
def clear(self, *args, **kwargs):
"""Remove all elements from the set.
Args:
args: The args passed.
kwargs: The kwargs passed.
"""
super().clear(*args, **kwargs)
self._reassign_field()
def update(self, *args, **kwargs):
"""Adds elements from an iterable to the set.
Args:
args: The args passed.
kwargs: The kwargs passed.
"""
super().update(*args, **kwargs)
self._reassign_field()
class ImportVar(Base):
"""An import var."""

View File

@ -3,7 +3,7 @@ import contextlib
import os
import platform
from pathlib import Path
from typing import Dict, Generator, List, Union
from typing import Dict, Generator, List, Set, Union
import pytest
@ -561,6 +561,7 @@ def mutable_state():
"another_key": "another_value",
"third_key": {"key": "value"},
}
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
def reassign_mutables(self):
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
@ -569,5 +570,6 @@ def mutable_state():
"mod_another_key": "another_value",
"mod_third_key": {"key": "value"},
}
self.test_set = {1, 2, 3, 4, "five"}
return MutableTestState()

View File

@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar
from reflex.event import Event, EventHandler
from reflex.state import State
from reflex.utils import format
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet
class Object(Base):
@ -1164,6 +1164,7 @@ def test_setattr_of_mutable_types(mutable_state):
"""
array = mutable_state.array
hashmap = mutable_state.hashmap
test_set = mutable_state.test_set
assert isinstance(array, ReflexList)
assert isinstance(array[1], ReflexList)
@ -1173,10 +1174,14 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(hashmap["key"], ReflexList)
assert isinstance(hashmap["third_key"], ReflexDict)
assert isinstance(test_set, set)
mutable_state.reassign_mutables()
array = mutable_state.array
hashmap = mutable_state.hashmap
test_set = mutable_state.test_set
assert isinstance(array, ReflexList)
assert isinstance(array[1], ReflexList)
assert isinstance(array[2], ReflexDict)
@ -1184,3 +1189,5 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(hashmap, ReflexDict)
assert isinstance(hashmap["mod_key"], ReflexList)
assert isinstance(hashmap["mod_third_key"], ReflexDict)
assert isinstance(test_set, ReflexSet)

View File

@ -13,6 +13,7 @@ from reflex.vars import (
ImportVar,
ReflexDict,
ReflexList,
ReflexSet,
Var,
get_local_storage,
)
@ -532,6 +533,16 @@ def test_pickleable_rx_dict():
assert cloudpickle.loads(pickled_dict) == rx_dict
def test_pickleable_rx_set():
"""Test that ReflexSet is pickleable."""
rx_set = ReflexSet(
original_set={1, 2, 3}, reassign_field=lambda x: x, field_name="random"
)
pickled_set = cloudpickle.dumps(rx_set)
assert cloudpickle.loads(pickled_set) == rx_set
@pytest.mark.parametrize(
"import_var,expected",
zip(