diff --git a/reflex/state.py b/reflex/state.py index 37613fae8..48a77639c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -595,6 +595,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): self.mark_dirty() return + # Make sure lists and dicts are converted to ReflexList and ReflexDict. + if name in self.vars and types._isinstance(value, Union[List, Dict]): + value = _convert_mutable_datatypes(value, self._reassign_field, name) + # Set the attribute. super().__setattr__(name, value) @@ -990,21 +994,22 @@ def _convert_mutable_datatypes( The converted field_value """ if isinstance(field_value, list): - for index in range(len(field_value)): - field_value[index] = _convert_mutable_datatypes( - field_value[index], reassign_field, field_name - ) + field_value = [ + _convert_mutable_datatypes(value, reassign_field, field_name) + for value in field_value + ] field_value = ReflexList( field_value, reassign_field=reassign_field, field_name=field_name ) if isinstance(field_value, dict): - for key, value in field_value.items(): - field_value[key] = _convert_mutable_datatypes( - value, reassign_field, field_name - ) + field_value = { + key: _convert_mutable_datatypes(value, reassign_field, field_name) + for key, value in field_value.items() + } field_value = ReflexDict( field_value, reassign_field=reassign_field, field_name=field_name ) + return field_value diff --git a/tests/conftest.py b/tests/conftest.py index 807860ac6..e0c45efe3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import contextlib import os import platform from pathlib import Path -from typing import Dict, Generator, List +from typing import Dict, Generator, List, Union import pytest @@ -538,3 +538,36 @@ def tmp_working_dir(tmp_path): working_dir.mkdir() with chdir(working_dir): yield working_dir + + +@pytest.fixture +def mutable_state(): + """Create a Test state containing mutable types. + + Returns: + A state object. + """ + + class MutableTestState(rx.State): + """A test state.""" + + array: List[Union[str, List, Dict[str, str]]] = [ + "value", + [1, 2, 3], + {"key": "value"}, + ] + hashmap: Dict[str, Union[List, str, Dict[str, str]]] = { + "key": ["list", "of", "values"], + "another_key": "another_value", + "third_key": {"key": "value"}, + } + + def reassign_mutables(self): + self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}] + self.hashmap = { + "mod_key": ["list", "of", "values"], + "mod_another_key": "another_value", + "mod_third_key": {"key": "value"}, + } + + return MutableTestState() diff --git a/tests/test_state.py b/tests/test_state.py index e6332dfcd..12da3635f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar from reflex.event import Event, EventHandler from reflex.state import State from reflex.utils import format -from reflex.vars import BaseVar, ComputedVar +from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList class Object(Base): @@ -1140,3 +1140,33 @@ def test_backend_method(): bms = BackendMethodState() bms.handler() assert bms._be_method() + + +def test_setattr_of_mutable_types(mutable_state): + """Test that mutable types are converted to corresponding Reflex wrappers. + + Args: + mutable_state: A test state. + """ + array = mutable_state.array + hashmap = mutable_state.hashmap + + assert isinstance(array, ReflexList) + assert isinstance(array[1], ReflexList) + assert isinstance(array[2], ReflexDict) + + assert isinstance(hashmap, ReflexDict) + assert isinstance(hashmap["key"], ReflexList) + assert isinstance(hashmap["third_key"], ReflexDict) + + mutable_state.reassign_mutables() + + array = mutable_state.array + hashmap = mutable_state.hashmap + assert isinstance(array, ReflexList) + assert isinstance(array[1], ReflexList) + assert isinstance(array[2], ReflexDict) + + assert isinstance(hashmap, ReflexDict) + assert isinstance(hashmap["mod_key"], ReflexList) + assert isinstance(hashmap["mod_third_key"], ReflexDict)