Added support for RelexSet
wrapper (#1535)
This commit is contained in:
parent
cd47815a4d
commit
ef78465f16
@ -34,7 +34,7 @@ from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
|
||||
from reflex.utils import format, prerequisites, types
|
||||
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, Var
|
||||
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet, Var
|
||||
|
||||
Delta = Dict[str, Any]
|
||||
|
||||
@ -601,8 +601,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self.mark_dirty()
|
||||
return
|
||||
|
||||
# Make sure lists and dicts are converted to ReflexList and ReflexDict.
|
||||
if name in self.vars and types._isinstance(value, Union[List, Dict]):
|
||||
# Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
|
||||
if name in self.vars and types._isinstance(value, Union[List, Dict, Set]):
|
||||
value = _convert_mutable_datatypes(value, self._reassign_field, name)
|
||||
|
||||
# Set the attribute.
|
||||
@ -985,7 +985,7 @@ def _convert_mutable_datatypes(
|
||||
) -> Any:
|
||||
"""Recursively convert mutable data to the Rx data types.
|
||||
|
||||
Note: right now only list & dict would be handled recursively.
|
||||
Note: right now only list, dict and set would be handled recursively.
|
||||
|
||||
Args:
|
||||
field_value: The target field_value.
|
||||
@ -1015,4 +1015,14 @@ def _convert_mutable_datatypes(
|
||||
field_value, reassign_field=reassign_field, field_name=field_name
|
||||
)
|
||||
|
||||
if isinstance(field_value, set):
|
||||
field_value = [
|
||||
_convert_mutable_datatypes(value, reassign_field, field_name)
|
||||
for value in field_value
|
||||
]
|
||||
|
||||
field_value = ReflexSet(
|
||||
field_value, reassign_field=reassign_field, field_name=field_name
|
||||
)
|
||||
|
||||
return field_value
|
||||
|
@ -1130,6 +1130,91 @@ class ReflexDict(dict):
|
||||
self._reassign_field()
|
||||
|
||||
|
||||
class ReflexSet(set):
|
||||
"""A custom set that reflex can detect its mutation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
original_set: Set,
|
||||
reassign_field: Callable = lambda _field_name: None,
|
||||
field_name: str = "",
|
||||
):
|
||||
"""Initialize ReflexSet.
|
||||
|
||||
Args:
|
||||
original_set (Set): The original set
|
||||
reassign_field (Callable):
|
||||
The method in the parent state to reassign the field.
|
||||
Default to be a no-op function
|
||||
field_name (str): the name of field in the parent state
|
||||
"""
|
||||
self._reassign_field = lambda: reassign_field(field_name)
|
||||
|
||||
super().__init__(original_set)
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
"""Add an element to set.
|
||||
|
||||
Args:
|
||||
args: The args passed.
|
||||
kwargs: The kwargs passed.
|
||||
"""
|
||||
super().add(*args, **kwargs)
|
||||
self._reassign_field()
|
||||
|
||||
def remove(self, *args, **kwargs):
|
||||
"""Remove an element.
|
||||
Raise key error if element not found.
|
||||
|
||||
Args:
|
||||
args: The args passed.
|
||||
kwargs: The kwargs passed.
|
||||
"""
|
||||
super().remove(*args, **kwargs)
|
||||
self._reassign_field()
|
||||
|
||||
def discard(self, *args, **kwargs):
|
||||
"""Remove an element.
|
||||
Does not raise key error if element not found.
|
||||
|
||||
Args:
|
||||
args: The args passed.
|
||||
kwargs: The kwargs passed.
|
||||
"""
|
||||
super().discard(*args, **kwargs)
|
||||
self._reassign_field()
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
"""Remove an element.
|
||||
|
||||
Args:
|
||||
args: The args passed.
|
||||
kwargs: The kwargs passed.
|
||||
"""
|
||||
super().pop(*args, **kwargs)
|
||||
self._reassign_field()
|
||||
|
||||
def clear(self, *args, **kwargs):
|
||||
"""Remove all elements from the set.
|
||||
|
||||
Args:
|
||||
args: The args passed.
|
||||
kwargs: The kwargs passed.
|
||||
"""
|
||||
super().clear(*args, **kwargs)
|
||||
self._reassign_field()
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""Adds elements from an iterable to the set.
|
||||
|
||||
Args:
|
||||
args: The args passed.
|
||||
kwargs: The kwargs passed.
|
||||
"""
|
||||
super().update(*args, **kwargs)
|
||||
self._reassign_field()
|
||||
|
||||
|
||||
class ImportVar(Base):
|
||||
"""An import var."""
|
||||
|
||||
|
@ -3,7 +3,7 @@ import contextlib
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, List, Union
|
||||
from typing import Dict, Generator, List, Set, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@ -561,6 +561,7 @@ def mutable_state():
|
||||
"another_key": "another_value",
|
||||
"third_key": {"key": "value"},
|
||||
}
|
||||
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
|
||||
|
||||
def reassign_mutables(self):
|
||||
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
|
||||
@ -569,5 +570,6 @@ def mutable_state():
|
||||
"mod_another_key": "another_value",
|
||||
"mod_third_key": {"key": "value"},
|
||||
}
|
||||
self.test_set = {1, 2, 3, 4, "five"}
|
||||
|
||||
return MutableTestState()
|
||||
|
@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar
|
||||
from reflex.event import Event, EventHandler
|
||||
from reflex.state import State
|
||||
from reflex.utils import format
|
||||
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList
|
||||
from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList, ReflexSet
|
||||
|
||||
|
||||
class Object(Base):
|
||||
@ -1164,6 +1164,7 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
"""
|
||||
array = mutable_state.array
|
||||
hashmap = mutable_state.hashmap
|
||||
test_set = mutable_state.test_set
|
||||
|
||||
assert isinstance(array, ReflexList)
|
||||
assert isinstance(array[1], ReflexList)
|
||||
@ -1173,10 +1174,14 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
assert isinstance(hashmap["key"], ReflexList)
|
||||
assert isinstance(hashmap["third_key"], ReflexDict)
|
||||
|
||||
assert isinstance(test_set, set)
|
||||
|
||||
mutable_state.reassign_mutables()
|
||||
|
||||
array = mutable_state.array
|
||||
hashmap = mutable_state.hashmap
|
||||
test_set = mutable_state.test_set
|
||||
|
||||
assert isinstance(array, ReflexList)
|
||||
assert isinstance(array[1], ReflexList)
|
||||
assert isinstance(array[2], ReflexDict)
|
||||
@ -1184,3 +1189,5 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
assert isinstance(hashmap, ReflexDict)
|
||||
assert isinstance(hashmap["mod_key"], ReflexList)
|
||||
assert isinstance(hashmap["mod_third_key"], ReflexDict)
|
||||
|
||||
assert isinstance(test_set, ReflexSet)
|
||||
|
@ -13,6 +13,7 @@ from reflex.vars import (
|
||||
ImportVar,
|
||||
ReflexDict,
|
||||
ReflexList,
|
||||
ReflexSet,
|
||||
Var,
|
||||
get_local_storage,
|
||||
)
|
||||
@ -532,6 +533,16 @@ def test_pickleable_rx_dict():
|
||||
assert cloudpickle.loads(pickled_dict) == rx_dict
|
||||
|
||||
|
||||
def test_pickleable_rx_set():
|
||||
"""Test that ReflexSet is pickleable."""
|
||||
rx_set = ReflexSet(
|
||||
original_set={1, 2, 3}, reassign_field=lambda x: x, field_name="random"
|
||||
)
|
||||
|
||||
pickled_set = cloudpickle.dumps(rx_set)
|
||||
assert cloudpickle.loads(pickled_set) == rx_set
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"import_var,expected",
|
||||
zip(
|
||||
|
Loading…
Reference in New Issue
Block a user