MutableProxy wraps values yielded by __iter__ (#1876)
This commit is contained in:
parent
bd0cd18796
commit
5ca7f29853
@ -1647,6 +1647,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
"update",
|
||||
]
|
||||
)
|
||||
# Methods on wrapped objects might return mutable objects that should be tracked.
|
||||
__wrap_mutable_attrs__ = set(
|
||||
[
|
||||
"get",
|
||||
"setdefault",
|
||||
]
|
||||
)
|
||||
|
||||
__mutable_types__ = (list, dict, set, Base)
|
||||
|
||||
@ -1663,7 +1670,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
self._self_state = state
|
||||
self._self_field_name = field_name
|
||||
|
||||
def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
|
||||
def _mark_dirty(
|
||||
self,
|
||||
wrapped=None,
|
||||
instance=None,
|
||||
args=tuple(),
|
||||
kwargs=None,
|
||||
) -> Any:
|
||||
"""Mark the state as dirty, then call a wrapped function.
|
||||
|
||||
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
||||
@ -1673,11 +1686,47 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
instance: The instance of the wrapped function.
|
||||
args: The args for the wrapped function.
|
||||
kwargs: The kwargs for the wrapped function.
|
||||
|
||||
Returns:
|
||||
The result of the wrapped function.
|
||||
"""
|
||||
self._self_state.dirty_vars.add(self._self_field_name)
|
||||
self._self_state._mark_dirty()
|
||||
if wrapped is not None:
|
||||
wrapped(*args, **(kwargs or {}))
|
||||
return wrapped(*args, **(kwargs or {}))
|
||||
|
||||
def _wrap_recursive(self, value: Any) -> Any:
|
||||
"""Wrap a value recursively if it is mutable.
|
||||
|
||||
Args:
|
||||
value: The value to wrap.
|
||||
|
||||
Returns:
|
||||
The wrapped value.
|
||||
"""
|
||||
if isinstance(value, self.__mutable_types__):
|
||||
return type(self)(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
field_name=self._self_field_name,
|
||||
)
|
||||
return value
|
||||
|
||||
def _wrap_recursive_decorator(self, wrapped, instance, args, kwargs) -> Any:
|
||||
"""Wrap a function that returns a possibly mutable value.
|
||||
|
||||
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
||||
|
||||
Args:
|
||||
wrapped: The wrapped function.
|
||||
instance: The instance of the wrapped function.
|
||||
args: The args for the wrapped function.
|
||||
kwargs: The kwargs for the wrapped function.
|
||||
|
||||
Returns:
|
||||
The result of the wrapped function (possibly wrapped in a MutableProxy).
|
||||
"""
|
||||
return self._wrap_recursive(wrapped(*args, **kwargs))
|
||||
|
||||
def __getattribute__(self, __name: str) -> Any:
|
||||
"""Get the attribute on the proxied object and return a proxy if mutable.
|
||||
@ -1690,24 +1739,26 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
"""
|
||||
value = super().__getattribute__(__name)
|
||||
|
||||
if callable(value) and __name in super().__getattribute__(
|
||||
"__mark_dirty_attrs__"
|
||||
):
|
||||
# Wrap special callables, like "append", which should mark state dirty.
|
||||
return wrapt.FunctionWrapper(
|
||||
value,
|
||||
super().__getattribute__("_mark_dirty"),
|
||||
)
|
||||
if callable(value):
|
||||
if __name in super().__getattribute__("__mark_dirty_attrs__"):
|
||||
# Wrap special callables, like "append", which should mark state dirty.
|
||||
value = wrapt.FunctionWrapper(
|
||||
value,
|
||||
super().__getattribute__("_mark_dirty"),
|
||||
)
|
||||
|
||||
if __name in super().__getattribute__("__wrap_mutable_attrs__"):
|
||||
# Wrap methods that may return mutable objects tied to the state.
|
||||
value = wrapt.FunctionWrapper(
|
||||
value,
|
||||
super().__getattribute__("_wrap_recursive_decorator"),
|
||||
)
|
||||
|
||||
if isinstance(
|
||||
value, super().__getattribute__("__mutable_types__")
|
||||
) and __name not in ("__wrapped__", "_self_state"):
|
||||
# Recursively wrap mutable attribute values retrieved through this proxy.
|
||||
return type(self)(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
field_name=self._self_field_name,
|
||||
)
|
||||
return self._wrap_recursive(value)
|
||||
|
||||
return value
|
||||
|
||||
@ -1721,14 +1772,18 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
The item value.
|
||||
"""
|
||||
value = super().__getitem__(key)
|
||||
if isinstance(value, self.__mutable_types__):
|
||||
# Recursively wrap mutable items retrieved through this proxy.
|
||||
return self._wrap_recursive(value)
|
||||
|
||||
def __iter__(self) -> Any:
|
||||
"""Iterate over the proxied object and return a proxy if mutable.
|
||||
|
||||
Yields:
|
||||
Each item value (possibly wrapped in MutableProxy).
|
||||
"""
|
||||
for value in super().__iter__():
|
||||
# Recursively wrap mutable items retrieved through this proxy.
|
||||
return type(self)(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
field_name=self._self_field_name,
|
||||
)
|
||||
return value
|
||||
yield self._wrap_recursive(value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
"""Delete the attribute on the proxied object and mark state dirty.
|
||||
|
@ -1858,6 +1858,15 @@ def test_mutable_list(mutable_state):
|
||||
assert_array_dirty()
|
||||
assert isinstance(mutable_state.array[0], MutableProxy)
|
||||
|
||||
# Test proxy returned from __iter__
|
||||
mutable_state.array = [{}]
|
||||
assert_array_dirty()
|
||||
assert isinstance(mutable_state.array[0], MutableProxy)
|
||||
for item in mutable_state.array:
|
||||
assert isinstance(item, MutableProxy)
|
||||
item["foo"] = "bar"
|
||||
assert_array_dirty()
|
||||
|
||||
|
||||
def test_mutable_dict(mutable_state):
|
||||
"""Test that mutable dicts are tracked correctly.
|
||||
@ -1875,9 +1884,13 @@ def test_mutable_dict(mutable_state):
|
||||
# Test all dict operations
|
||||
mutable_state.hashmap.update({"new_key": 43})
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap.setdefault("another_key", 66)
|
||||
assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap.pop("new_key")
|
||||
assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_state.hashmap.pop("new_key") == 43
|
||||
assert_hashmap_dirty()
|
||||
mutable_state.hashmap.popitem()
|
||||
assert_hashmap_dirty()
|
||||
@ -1905,6 +1918,31 @@ def test_mutable_dict(mutable_state):
|
||||
mutable_state.hashmap["dict"]["dict"]["key"] = 43
|
||||
assert_hashmap_dirty()
|
||||
|
||||
# Test proxy returned from `setdefault` and `get`
|
||||
mutable_value = mutable_state.hashmap.setdefault("setdefault_mutable_key", [])
|
||||
assert_hashmap_dirty()
|
||||
assert mutable_value == []
|
||||
assert isinstance(mutable_value, MutableProxy)
|
||||
mutable_value.append("foo")
|
||||
assert_hashmap_dirty()
|
||||
mutable_value_other_ref = mutable_state.hashmap.get("setdefault_mutable_key")
|
||||
assert isinstance(mutable_value_other_ref, MutableProxy)
|
||||
assert mutable_value is not mutable_value_other_ref
|
||||
assert mutable_value == mutable_value_other_ref
|
||||
assert not mutable_state.dirty_vars
|
||||
mutable_value_other_ref.append("bar")
|
||||
assert_hashmap_dirty()
|
||||
|
||||
# `pop` should NOT return a proxy, because the returned value is no longer in the dict
|
||||
mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
|
||||
assert not isinstance(mutable_value_third_ref, MutableProxy)
|
||||
assert_hashmap_dirty()
|
||||
mutable_value_third_ref.append("baz")
|
||||
assert not mutable_state.dirty_vars
|
||||
# Unfortunately previous refs still will mark the state dirty... nothing doing about that
|
||||
assert mutable_value.pop()
|
||||
assert_hashmap_dirty()
|
||||
|
||||
|
||||
def test_mutable_set(mutable_state):
|
||||
"""Test that mutable sets are tracked correctly.
|
||||
|
Loading…
Reference in New Issue
Block a user