MutableProxy wraps values yielded by __iter__ (#1876)

This commit is contained in:
Masen Furer 2023-09-28 17:34:46 -07:00 committed by GitHub
parent bd0cd18796
commit 5ca7f29853
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 24 deletions

View File

@ -1647,6 +1647,13 @@ class MutableProxy(wrapt.ObjectProxy):
"update",
]
)
# Methods on wrapped objects might return mutable objects that should be tracked.
__wrap_mutable_attrs__ = set(
[
"get",
"setdefault",
]
)
__mutable_types__ = (list, dict, set, Base)
@ -1663,7 +1670,13 @@ class MutableProxy(wrapt.ObjectProxy):
self._self_state = state
self._self_field_name = field_name
def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
def _mark_dirty(
self,
wrapped=None,
instance=None,
args=tuple(),
kwargs=None,
) -> Any:
"""Mark the state as dirty, then call a wrapped function.
Intended for use with `FunctionWrapper` from the `wrapt` library.
@ -1673,11 +1686,47 @@ class MutableProxy(wrapt.ObjectProxy):
instance: The instance of the wrapped function.
args: The args for the wrapped function.
kwargs: The kwargs for the wrapped function.
Returns:
The result of the wrapped function.
"""
self._self_state.dirty_vars.add(self._self_field_name)
self._self_state._mark_dirty()
if wrapped is not None:
wrapped(*args, **(kwargs or {}))
return wrapped(*args, **(kwargs or {}))
def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable.
Args:
value: The value to wrap.
Returns:
The wrapped value.
"""
if isinstance(value, self.__mutable_types__):
return type(self)(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
)
return value
def _wrap_recursive_decorator(self, wrapped, instance, args, kwargs) -> Any:
"""Wrap a function that returns a possibly mutable value.
Intended for use with `FunctionWrapper` from the `wrapt` library.
Args:
wrapped: The wrapped function.
instance: The instance of the wrapped function.
args: The args for the wrapped function.
kwargs: The kwargs for the wrapped function.
Returns:
The result of the wrapped function (possibly wrapped in a MutableProxy).
"""
return self._wrap_recursive(wrapped(*args, **kwargs))
def __getattribute__(self, __name: str) -> Any:
"""Get the attribute on the proxied object and return a proxy if mutable.
@ -1690,24 +1739,26 @@ class MutableProxy(wrapt.ObjectProxy):
"""
value = super().__getattribute__(__name)
if callable(value) and __name in super().__getattribute__(
"__mark_dirty_attrs__"
):
# Wrap special callables, like "append", which should mark state dirty.
return wrapt.FunctionWrapper(
value,
super().__getattribute__("_mark_dirty"),
)
if callable(value):
if __name in super().__getattribute__("__mark_dirty_attrs__"):
# Wrap special callables, like "append", which should mark state dirty.
value = wrapt.FunctionWrapper(
value,
super().__getattribute__("_mark_dirty"),
)
if __name in super().__getattribute__("__wrap_mutable_attrs__"):
# Wrap methods that may return mutable objects tied to the state.
value = wrapt.FunctionWrapper(
value,
super().__getattribute__("_wrap_recursive_decorator"),
)
if isinstance(
value, super().__getattribute__("__mutable_types__")
) and __name not in ("__wrapped__", "_self_state"):
# Recursively wrap mutable attribute values retrieved through this proxy.
return type(self)(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
)
return self._wrap_recursive(value)
return value
@ -1721,14 +1772,18 @@ class MutableProxy(wrapt.ObjectProxy):
The item value.
"""
value = super().__getitem__(key)
if isinstance(value, self.__mutable_types__):
# Recursively wrap mutable items retrieved through this proxy.
return self._wrap_recursive(value)
def __iter__(self) -> Any:
"""Iterate over the proxied object and return a proxy if mutable.
Yields:
Each item value (possibly wrapped in MutableProxy).
"""
for value in super().__iter__():
# Recursively wrap mutable items retrieved through this proxy.
return type(self)(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
)
return value
yield self._wrap_recursive(value)
def __delattr__(self, name):
"""Delete the attribute on the proxied object and mark state dirty.

View File

@ -1858,6 +1858,15 @@ def test_mutable_list(mutable_state):
assert_array_dirty()
assert isinstance(mutable_state.array[0], MutableProxy)
# Test proxy returned from __iter__
mutable_state.array = [{}]
assert_array_dirty()
assert isinstance(mutable_state.array[0], MutableProxy)
for item in mutable_state.array:
assert isinstance(item, MutableProxy)
item["foo"] = "bar"
assert_array_dirty()
def test_mutable_dict(mutable_state):
"""Test that mutable dicts are tracked correctly.
@ -1875,9 +1884,13 @@ def test_mutable_dict(mutable_state):
# Test all dict operations
mutable_state.hashmap.update({"new_key": 43})
assert_hashmap_dirty()
mutable_state.hashmap.setdefault("another_key", 66)
assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
assert_hashmap_dirty()
mutable_state.hashmap.pop("new_key")
assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
assert_hashmap_dirty()
assert mutable_state.hashmap.pop("new_key") == 43
assert_hashmap_dirty()
mutable_state.hashmap.popitem()
assert_hashmap_dirty()
@ -1905,6 +1918,31 @@ def test_mutable_dict(mutable_state):
mutable_state.hashmap["dict"]["dict"]["key"] = 43
assert_hashmap_dirty()
# Test proxy returned from `setdefault` and `get`
mutable_value = mutable_state.hashmap.setdefault("setdefault_mutable_key", [])
assert_hashmap_dirty()
assert mutable_value == []
assert isinstance(mutable_value, MutableProxy)
mutable_value.append("foo")
assert_hashmap_dirty()
mutable_value_other_ref = mutable_state.hashmap.get("setdefault_mutable_key")
assert isinstance(mutable_value_other_ref, MutableProxy)
assert mutable_value is not mutable_value_other_ref
assert mutable_value == mutable_value_other_ref
assert not mutable_state.dirty_vars
mutable_value_other_ref.append("bar")
assert_hashmap_dirty()
# `pop` should NOT return a proxy, because the returned value is no longer in the dict
mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
assert not isinstance(mutable_value_third_ref, MutableProxy)
assert_hashmap_dirty()
mutable_value_third_ref.append("baz")
assert not mutable_state.dirty_vars
# Unfortunately previous refs still will mark the state dirty... nothing doing about that
assert mutable_value.pop()
assert_hashmap_dirty()
def test_mutable_set(mutable_state):
"""Test that mutable sets are tracked correctly.