diff --git a/reflex/state.py b/reflex/state.py index 1b5eda865..dfe9803c7 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1313,3 +1313,22 @@ class MutableProxy(wrapt.ObjectProxy): super().__setattr__(name, value) return self._mark_dirty(super().__setattr__, args=(name, value)) + + def __copy__(self) -> Any: + """Return a copy of the proxy. + + Returns: + A copy of the wrapped object, unconnected to the proxy. + """ + return copy.copy(self.__wrapped__) + + def __deepcopy__(self, memo=None) -> Any: + """Return a deepcopy of the proxy. + + Args: + memo: The memo dict to use for the deepcopy. + + Returns: + A deepcopy of the wrapped object, unconnected to the proxy. + """ + return copy.deepcopy(self.__wrapped__, memo=memo) diff --git a/tests/test_state.py b/tests/test_state.py index 763d36a6d..cf59e5eb9 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import datetime import functools import sys @@ -1585,6 +1586,56 @@ def test_mutable_backend(mutable_state): assert_custom_dirty() +@pytest.mark.parametrize( + ("copy_func",), + [ + (copy.copy,), + (copy.deepcopy,), + ], +) +def test_mutable_copy(mutable_state, copy_func): + """Test that mutable types are copied correctly. + + Args: + mutable_state: A test state. + copy_func: A copy function. + """ + ms_copy = copy_func(mutable_state) + assert ms_copy is not mutable_state + for attr in ("array", "hashmap", "test_set", "custom"): + assert getattr(ms_copy, attr) == getattr(mutable_state, attr) + assert getattr(ms_copy, attr) is not getattr(mutable_state, attr) + ms_copy.custom.array.append(42) + assert "custom" in ms_copy.dirty_vars + if copy_func is copy.copy: + assert "custom" in mutable_state.dirty_vars + else: + assert not mutable_state.dirty_vars + + +@pytest.mark.parametrize( + ("copy_func",), + [ + (copy.copy,), + (copy.deepcopy,), + ], +) +def test_mutable_copy_vars(mutable_state, copy_func): + """Test that mutable types are copied correctly. + + Args: + mutable_state: A test state. + copy_func: A copy function. + """ + for attr in ("array", "hashmap", "test_set", "custom"): + var_orig = getattr(mutable_state, attr) + var_copy = copy_func(var_orig) + assert var_orig is not var_copy + assert var_orig == var_copy + # copied vars should never be proxies, as they by definition are no longer attached to the state. + assert not isinstance(var_copy, MutableProxy) + + def test_duplicate_substate_class(duplicate_substate): with pytest.raises(ValueError): duplicate_substate()