State.reset uses deepcopy on defaults (#1889)

This commit is contained in:
Masen Furer 2023-09-29 16:33:16 -07:00 committed by GitHub
parent 5ca7f29853
commit 4f6b3c049b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 2 deletions

View File

@ -679,7 +679,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Reset the base vars.
fields = self.get_fields()
for prop_name in self.base_vars:
setattr(self, prop_name, fields[prop_name].default)
setattr(self, prop_name, copy.deepcopy(fields[prop_name].default))
# Recursively reset the substates.
for substate in self.substates.values():
@ -696,7 +696,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
isinstance(field.type_, type)
and issubclass(field.type_, ClientStorageBase)
):
setattr(self, prop_name, field.default)
setattr(self, prop_name, copy.deepcopy(field.default))
# Recursively reset the substate client storage.
for substate in self.substates.values():

View File

@ -2089,3 +2089,33 @@ def test_mutable_copy_vars(mutable_state, copy_func):
def test_duplicate_substate_class(duplicate_substate):
with pytest.raises(ValueError):
duplicate_substate()
def test_reset_with_mutables():
"""Calling reset should always reset fields to a copy of the defaults."""
default = [[0, 0], [0, 1], [1, 1]]
copied_default = copy.deepcopy(default)
class MutableResetState(State):
items: List[List[int]] = default
instance = MutableResetState()
assert instance.items.__wrapped__ is not default # type: ignore
assert instance.items == default == copied_default
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default
instance.reset()
assert instance.items.__wrapped__ is not default # type: ignore
assert instance.items == default == copied_default
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default
instance.reset()
assert instance.items.__wrapped__ is not default # type: ignore
assert instance.items == default == copied_default
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default