state: implement __copy__ and __deepcopy__ for MutableProxy (#1845)

This commit is contained in:
Masen Furer 2023-09-20 16:46:49 -07:00 committed by GitHub
parent 3113aecb30
commit 1bfb579b20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 0 deletions

View File

@ -1313,3 +1313,22 @@ class MutableProxy(wrapt.ObjectProxy):
super().__setattr__(name, value)
return
self._mark_dirty(super().__setattr__, args=(name, value))
def __copy__(self) -> Any:
"""Return a copy of the proxy.
Returns:
A copy of the wrapped object, unconnected to the proxy.
"""
return copy.copy(self.__wrapped__)
def __deepcopy__(self, memo=None) -> Any:
"""Return a deepcopy of the proxy.
Args:
memo: The memo dict to use for the deepcopy.
Returns:
A deepcopy of the wrapped object, unconnected to the proxy.
"""
return copy.deepcopy(self.__wrapped__, memo=memo)

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import copy
import datetime
import functools
import sys
@ -1585,6 +1586,56 @@ def test_mutable_backend(mutable_state):
assert_custom_dirty()
@pytest.mark.parametrize(
("copy_func",),
[
(copy.copy,),
(copy.deepcopy,),
],
)
def test_mutable_copy(mutable_state, copy_func):
"""Test that mutable types are copied correctly.
Args:
mutable_state: A test state.
copy_func: A copy function.
"""
ms_copy = copy_func(mutable_state)
assert ms_copy is not mutable_state
for attr in ("array", "hashmap", "test_set", "custom"):
assert getattr(ms_copy, attr) == getattr(mutable_state, attr)
assert getattr(ms_copy, attr) is not getattr(mutable_state, attr)
ms_copy.custom.array.append(42)
assert "custom" in ms_copy.dirty_vars
if copy_func is copy.copy:
assert "custom" in mutable_state.dirty_vars
else:
assert not mutable_state.dirty_vars
@pytest.mark.parametrize(
("copy_func",),
[
(copy.copy,),
(copy.deepcopy,),
],
)
def test_mutable_copy_vars(mutable_state, copy_func):
"""Test that mutable types are copied correctly.
Args:
mutable_state: A test state.
copy_func: A copy function.
"""
for attr in ("array", "hashmap", "test_set", "custom"):
var_orig = getattr(mutable_state, attr)
var_copy = copy_func(var_orig)
assert var_orig is not var_copy
assert var_orig == var_copy
# copied vars should never be proxies, as they by definition are no longer attached to the state.
assert not isinstance(var_copy, MutableProxy)
def test_duplicate_substate_class(duplicate_substate):
with pytest.raises(ValueError):
duplicate_substate()