diff --git a/reflex/state.py b/reflex/state.py index df7400114..571a6931d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -34,7 +34,7 @@ from reflex import constants from reflex.base import Base from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert from reflex.utils import format, prerequisites, types -from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, Var +from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet, Var Delta = Dict[str, Any] @@ -601,8 +601,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow): self.mark_dirty() return - # Make sure lists and dicts are converted to ReflexList and ReflexDict. - if name in self.vars and types._isinstance(value, Union[List, Dict]): + # Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet. + if name in self.vars and types._isinstance(value, Union[List, Dict, Set]): value = _convert_mutable_datatypes(value, self._reassign_field, name) # Set the attribute. @@ -985,7 +985,7 @@ def _convert_mutable_datatypes( ) -> Any: """Recursively convert mutable data to the Rx data types. - Note: right now only list & dict would be handled recursively. + Note: right now only list, dict and set would be handled recursively. Args: field_value: The target field_value. @@ -1015,4 +1015,14 @@ def _convert_mutable_datatypes( field_value, reassign_field=reassign_field, field_name=field_name ) + if isinstance(field_value, set): + field_value = [ + _convert_mutable_datatypes(value, reassign_field, field_name) + for value in field_value + ] + + field_value = ReflexSet( + field_value, reassign_field=reassign_field, field_name=field_name + ) + return field_value diff --git a/reflex/vars.py b/reflex/vars.py index 5898fca1a..a4d57b479 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1130,6 +1130,91 @@ class ReflexDict(dict): self._reassign_field() +class ReflexSet(set): + """A custom set that reflex can detect its mutation.""" + + def __init__( + self, + original_set: Set, + reassign_field: Callable = lambda _field_name: None, + field_name: str = "", + ): + """Initialize ReflexSet. + + Args: + original_set (Set): The original set + reassign_field (Callable): + The method in the parent state to reassign the field. + Default to be a no-op function + field_name (str): the name of field in the parent state + """ + self._reassign_field = lambda: reassign_field(field_name) + + super().__init__(original_set) + + def add(self, *args, **kwargs): + """Add an element to set. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().add(*args, **kwargs) + self._reassign_field() + + def remove(self, *args, **kwargs): + """Remove an element. + Raise key error if element not found. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().remove(*args, **kwargs) + self._reassign_field() + + def discard(self, *args, **kwargs): + """Remove an element. + Does not raise key error if element not found. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().discard(*args, **kwargs) + self._reassign_field() + + def pop(self, *args, **kwargs): + """Remove an element. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().pop(*args, **kwargs) + self._reassign_field() + + def clear(self, *args, **kwargs): + """Remove all elements from the set. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().clear(*args, **kwargs) + self._reassign_field() + + def update(self, *args, **kwargs): + """Adds elements from an iterable to the set. + + Args: + args: The args passed. + kwargs: The kwargs passed. + """ + super().update(*args, **kwargs) + self._reassign_field() + + class ImportVar(Base): """An import var.""" diff --git a/tests/conftest.py b/tests/conftest.py index e0c45efe3..7faa0c689 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import contextlib import os import platform from pathlib import Path -from typing import Dict, Generator, List, Union +from typing import Dict, Generator, List, Set, Union import pytest @@ -561,6 +561,7 @@ def mutable_state(): "another_key": "another_value", "third_key": {"key": "value"}, } + test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"} def reassign_mutables(self): self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}] @@ -569,5 +570,6 @@ def mutable_state(): "mod_another_key": "another_value", "mod_third_key": {"key": "value"}, } + self.test_set = {1, 2, 3, 4, "five"} return MutableTestState() diff --git a/tests/test_state.py b/tests/test_state.py index c4dab434f..352ba6156 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar from reflex.event import Event, EventHandler from reflex.state import State from reflex.utils import format -from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList +from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet class Object(Base): @@ -1164,6 +1164,7 @@ def test_setattr_of_mutable_types(mutable_state): """ array = mutable_state.array hashmap = mutable_state.hashmap + test_set = mutable_state.test_set assert isinstance(array, ReflexList) assert isinstance(array[1], ReflexList) @@ -1173,10 +1174,14 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(hashmap["key"], ReflexList) assert isinstance(hashmap["third_key"], ReflexDict) + assert isinstance(test_set, set) + mutable_state.reassign_mutables() array = mutable_state.array hashmap = mutable_state.hashmap + test_set = mutable_state.test_set + assert isinstance(array, ReflexList) assert isinstance(array[1], ReflexList) assert isinstance(array[2], ReflexDict) @@ -1184,3 +1189,5 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(hashmap, ReflexDict) assert isinstance(hashmap["mod_key"], ReflexList) assert isinstance(hashmap["mod_third_key"], ReflexDict) + + assert isinstance(test_set, ReflexSet) diff --git a/tests/test_var.py b/tests/test_var.py index 9dd0784ef..beb661f3c 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -13,6 +13,7 @@ from reflex.vars import ( ImportVar, ReflexDict, ReflexList, + ReflexSet, Var, get_local_storage, ) @@ -532,6 +533,16 @@ def test_pickleable_rx_dict(): assert cloudpickle.loads(pickled_dict) == rx_dict +def test_pickleable_rx_set(): + """Test that ReflexSet is pickleable.""" + rx_set = ReflexSet( + original_set={1, 2, 3}, reassign_field=lambda x: x, field_name="random" + ) + + pickled_set = cloudpickle.dumps(rx_set) + assert cloudpickle.loads(pickled_set) == rx_set + + @pytest.mark.parametrize( "import_var,expected", zip(