List and Dict mutation on setattr (#1428)

This commit is contained in:
Elijah Ahianyo 2023-07-27 19:45:57 +00:00 committed by GitHub
parent 43220438b6
commit 3fa33bd644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 10 deletions

View File

@ -595,6 +595,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self.mark_dirty() self.mark_dirty()
return return
# Make sure lists and dicts are converted to ReflexList and ReflexDict.
if name in self.vars and types._isinstance(value, Union[List, Dict]):
value = _convert_mutable_datatypes(value, self._reassign_field, name)
# Set the attribute. # Set the attribute.
super().__setattr__(name, value) super().__setattr__(name, value)
@ -990,21 +994,22 @@ def _convert_mutable_datatypes(
The converted field_value The converted field_value
""" """
if isinstance(field_value, list): if isinstance(field_value, list):
for index in range(len(field_value)): field_value = [
field_value[index] = _convert_mutable_datatypes( _convert_mutable_datatypes(value, reassign_field, field_name)
field_value[index], reassign_field, field_name for value in field_value
) ]
field_value = ReflexList( field_value = ReflexList(
field_value, reassign_field=reassign_field, field_name=field_name field_value, reassign_field=reassign_field, field_name=field_name
) )
if isinstance(field_value, dict): if isinstance(field_value, dict):
for key, value in field_value.items(): field_value = {
field_value[key] = _convert_mutable_datatypes( key: _convert_mutable_datatypes(value, reassign_field, field_name)
value, reassign_field, field_name for key, value in field_value.items()
) }
field_value = ReflexDict( field_value = ReflexDict(
field_value, reassign_field=reassign_field, field_name=field_name field_value, reassign_field=reassign_field, field_name=field_name
) )
return field_value return field_value

View File

@ -3,7 +3,7 @@ import contextlib
import os import os
import platform import platform
from pathlib import Path from pathlib import Path
from typing import Dict, Generator, List from typing import Dict, Generator, List, Union
import pytest import pytest
@ -538,3 +538,36 @@ def tmp_working_dir(tmp_path):
working_dir.mkdir() working_dir.mkdir()
with chdir(working_dir): with chdir(working_dir):
yield working_dir yield working_dir
@pytest.fixture
def mutable_state():
"""Create a Test state containing mutable types.
Returns:
A state object.
"""
class MutableTestState(rx.State):
"""A test state."""
array: List[Union[str, List, Dict[str, str]]] = [
"value",
[1, 2, 3],
{"key": "value"},
]
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
"key": ["list", "of", "values"],
"another_key": "another_value",
"third_key": {"key": "value"},
}
def reassign_mutables(self):
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
self.hashmap = {
"mod_key": ["list", "of", "values"],
"mod_another_key": "another_value",
"mod_third_key": {"key": "value"},
}
return MutableTestState()

View File

@ -10,7 +10,7 @@ from reflex.constants import IS_HYDRATED, RouteVar
from reflex.event import Event, EventHandler from reflex.event import Event, EventHandler
from reflex.state import State from reflex.state import State
from reflex.utils import format from reflex.utils import format
from reflex.vars import BaseVar, ComputedVar from reflex.vars import BaseVar, ComputedVar, ReflexDict, ReflexList
class Object(Base): class Object(Base):
@ -1140,3 +1140,33 @@ def test_backend_method():
bms = BackendMethodState() bms = BackendMethodState()
bms.handler() bms.handler()
assert bms._be_method() assert bms._be_method()
def test_setattr_of_mutable_types(mutable_state):
"""Test that mutable types are converted to corresponding Reflex wrappers.
Args:
mutable_state: A test state.
"""
array = mutable_state.array
hashmap = mutable_state.hashmap
assert isinstance(array, ReflexList)
assert isinstance(array[1], ReflexList)
assert isinstance(array[2], ReflexDict)
assert isinstance(hashmap, ReflexDict)
assert isinstance(hashmap["key"], ReflexList)
assert isinstance(hashmap["third_key"], ReflexDict)
mutable_state.reassign_mutables()
array = mutable_state.array
hashmap = mutable_state.hashmap
assert isinstance(array, ReflexList)
assert isinstance(array[1], ReflexList)
assert isinstance(array[2], ReflexDict)
assert isinstance(hashmap, ReflexDict)
assert isinstance(hashmap["mod_key"], ReflexList)
assert isinstance(hashmap["mod_third_key"], ReflexDict)