State.reset uses deepcopy on defaults (#1889)
This commit is contained in:
parent
5ca7f29853
commit
4f6b3c049b
@ -679,7 +679,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Reset the base vars.
|
||||
fields = self.get_fields()
|
||||
for prop_name in self.base_vars:
|
||||
setattr(self, prop_name, fields[prop_name].default)
|
||||
setattr(self, prop_name, copy.deepcopy(fields[prop_name].default))
|
||||
|
||||
# Recursively reset the substates.
|
||||
for substate in self.substates.values():
|
||||
@ -696,7 +696,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
isinstance(field.type_, type)
|
||||
and issubclass(field.type_, ClientStorageBase)
|
||||
):
|
||||
setattr(self, prop_name, field.default)
|
||||
setattr(self, prop_name, copy.deepcopy(field.default))
|
||||
|
||||
# Recursively reset the substate client storage.
|
||||
for substate in self.substates.values():
|
||||
|
@ -2089,3 +2089,33 @@ def test_mutable_copy_vars(mutable_state, copy_func):
|
||||
def test_duplicate_substate_class(duplicate_substate):
|
||||
with pytest.raises(ValueError):
|
||||
duplicate_substate()
|
||||
|
||||
|
||||
def test_reset_with_mutables():
|
||||
"""Calling reset should always reset fields to a copy of the defaults."""
|
||||
default = [[0, 0], [0, 1], [1, 1]]
|
||||
copied_default = copy.deepcopy(default)
|
||||
|
||||
class MutableResetState(State):
|
||||
items: List[List[int]] = default
|
||||
|
||||
instance = MutableResetState()
|
||||
assert instance.items.__wrapped__ is not default # type: ignore
|
||||
assert instance.items == default == copied_default
|
||||
instance.items.append([3, 3])
|
||||
assert instance.items != default
|
||||
assert instance.items != copied_default
|
||||
|
||||
instance.reset()
|
||||
assert instance.items.__wrapped__ is not default # type: ignore
|
||||
assert instance.items == default == copied_default
|
||||
instance.items.append([3, 3])
|
||||
assert instance.items != default
|
||||
assert instance.items != copied_default
|
||||
|
||||
instance.reset()
|
||||
assert instance.items.__wrapped__ is not default # type: ignore
|
||||
assert instance.items == default == copied_default
|
||||
instance.items.append([3, 3])
|
||||
assert instance.items != default
|
||||
assert instance.items != copied_default
|
||||
|
Loading…
Reference in New Issue
Block a user