MutableProxy wraps values yielded by __iter__ (#1876)
This commit is contained in:
parent
bd0cd18796
commit
5ca7f29853
@ -1647,6 +1647,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
"update",
|
"update",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
# Methods on wrapped objects might return mutable objects that should be tracked.
|
||||||
|
__wrap_mutable_attrs__ = set(
|
||||||
|
[
|
||||||
|
"get",
|
||||||
|
"setdefault",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
__mutable_types__ = (list, dict, set, Base)
|
__mutable_types__ = (list, dict, set, Base)
|
||||||
|
|
||||||
@ -1663,7 +1670,13 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
self._self_state = state
|
self._self_state = state
|
||||||
self._self_field_name = field_name
|
self._self_field_name = field_name
|
||||||
|
|
||||||
def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
|
def _mark_dirty(
|
||||||
|
self,
|
||||||
|
wrapped=None,
|
||||||
|
instance=None,
|
||||||
|
args=tuple(),
|
||||||
|
kwargs=None,
|
||||||
|
) -> Any:
|
||||||
"""Mark the state as dirty, then call a wrapped function.
|
"""Mark the state as dirty, then call a wrapped function.
|
||||||
|
|
||||||
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
||||||
@ -1673,11 +1686,47 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
instance: The instance of the wrapped function.
|
instance: The instance of the wrapped function.
|
||||||
args: The args for the wrapped function.
|
args: The args for the wrapped function.
|
||||||
kwargs: The kwargs for the wrapped function.
|
kwargs: The kwargs for the wrapped function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the wrapped function.
|
||||||
"""
|
"""
|
||||||
self._self_state.dirty_vars.add(self._self_field_name)
|
self._self_state.dirty_vars.add(self._self_field_name)
|
||||||
self._self_state._mark_dirty()
|
self._self_state._mark_dirty()
|
||||||
if wrapped is not None:
|
if wrapped is not None:
|
||||||
wrapped(*args, **(kwargs or {}))
|
return wrapped(*args, **(kwargs or {}))
|
||||||
|
|
||||||
|
def _wrap_recursive(self, value: Any) -> Any:
|
||||||
|
"""Wrap a value recursively if it is mutable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to wrap.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The wrapped value.
|
||||||
|
"""
|
||||||
|
if isinstance(value, self.__mutable_types__):
|
||||||
|
return type(self)(
|
||||||
|
wrapped=value,
|
||||||
|
state=self._self_state,
|
||||||
|
field_name=self._self_field_name,
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _wrap_recursive_decorator(self, wrapped, instance, args, kwargs) -> Any:
|
||||||
|
"""Wrap a function that returns a possibly mutable value.
|
||||||
|
|
||||||
|
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wrapped: The wrapped function.
|
||||||
|
instance: The instance of the wrapped function.
|
||||||
|
args: The args for the wrapped function.
|
||||||
|
kwargs: The kwargs for the wrapped function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the wrapped function (possibly wrapped in a MutableProxy).
|
||||||
|
"""
|
||||||
|
return self._wrap_recursive(wrapped(*args, **kwargs))
|
||||||
|
|
||||||
def __getattribute__(self, __name: str) -> Any:
|
def __getattribute__(self, __name: str) -> Any:
|
||||||
"""Get the attribute on the proxied object and return a proxy if mutable.
|
"""Get the attribute on the proxied object and return a proxy if mutable.
|
||||||
@ -1690,24 +1739,26 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
"""
|
"""
|
||||||
value = super().__getattribute__(__name)
|
value = super().__getattribute__(__name)
|
||||||
|
|
||||||
if callable(value) and __name in super().__getattribute__(
|
if callable(value):
|
||||||
"__mark_dirty_attrs__"
|
if __name in super().__getattribute__("__mark_dirty_attrs__"):
|
||||||
):
|
# Wrap special callables, like "append", which should mark state dirty.
|
||||||
# Wrap special callables, like "append", which should mark state dirty.
|
value = wrapt.FunctionWrapper(
|
||||||
return wrapt.FunctionWrapper(
|
value,
|
||||||
value,
|
super().__getattribute__("_mark_dirty"),
|
||||||
super().__getattribute__("_mark_dirty"),
|
)
|
||||||
)
|
|
||||||
|
if __name in super().__getattribute__("__wrap_mutable_attrs__"):
|
||||||
|
# Wrap methods that may return mutable objects tied to the state.
|
||||||
|
value = wrapt.FunctionWrapper(
|
||||||
|
value,
|
||||||
|
super().__getattribute__("_wrap_recursive_decorator"),
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
value, super().__getattribute__("__mutable_types__")
|
value, super().__getattribute__("__mutable_types__")
|
||||||
) and __name not in ("__wrapped__", "_self_state"):
|
) and __name not in ("__wrapped__", "_self_state"):
|
||||||
# Recursively wrap mutable attribute values retrieved through this proxy.
|
# Recursively wrap mutable attribute values retrieved through this proxy.
|
||||||
return type(self)(
|
return self._wrap_recursive(value)
|
||||||
wrapped=value,
|
|
||||||
state=self._self_state,
|
|
||||||
field_name=self._self_field_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -1721,14 +1772,18 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
The item value.
|
The item value.
|
||||||
"""
|
"""
|
||||||
value = super().__getitem__(key)
|
value = super().__getitem__(key)
|
||||||
if isinstance(value, self.__mutable_types__):
|
# Recursively wrap mutable items retrieved through this proxy.
|
||||||
|
return self._wrap_recursive(value)
|
||||||
|
|
||||||
|
def __iter__(self) -> Any:
|
||||||
|
"""Iterate over the proxied object and return a proxy if mutable.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Each item value (possibly wrapped in MutableProxy).
|
||||||
|
"""
|
||||||
|
for value in super().__iter__():
|
||||||
# Recursively wrap mutable items retrieved through this proxy.
|
# Recursively wrap mutable items retrieved through this proxy.
|
||||||
return type(self)(
|
yield self._wrap_recursive(value)
|
||||||
wrapped=value,
|
|
||||||
state=self._self_state,
|
|
||||||
field_name=self._self_field_name,
|
|
||||||
)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __delattr__(self, name):
|
def __delattr__(self, name):
|
||||||
"""Delete the attribute on the proxied object and mark state dirty.
|
"""Delete the attribute on the proxied object and mark state dirty.
|
||||||
|
@ -1858,6 +1858,15 @@ def test_mutable_list(mutable_state):
|
|||||||
assert_array_dirty()
|
assert_array_dirty()
|
||||||
assert isinstance(mutable_state.array[0], MutableProxy)
|
assert isinstance(mutable_state.array[0], MutableProxy)
|
||||||
|
|
||||||
|
# Test proxy returned from __iter__
|
||||||
|
mutable_state.array = [{}]
|
||||||
|
assert_array_dirty()
|
||||||
|
assert isinstance(mutable_state.array[0], MutableProxy)
|
||||||
|
for item in mutable_state.array:
|
||||||
|
assert isinstance(item, MutableProxy)
|
||||||
|
item["foo"] = "bar"
|
||||||
|
assert_array_dirty()
|
||||||
|
|
||||||
|
|
||||||
def test_mutable_dict(mutable_state):
|
def test_mutable_dict(mutable_state):
|
||||||
"""Test that mutable dicts are tracked correctly.
|
"""Test that mutable dicts are tracked correctly.
|
||||||
@ -1875,9 +1884,13 @@ def test_mutable_dict(mutable_state):
|
|||||||
# Test all dict operations
|
# Test all dict operations
|
||||||
mutable_state.hashmap.update({"new_key": 43})
|
mutable_state.hashmap.update({"new_key": 43})
|
||||||
assert_hashmap_dirty()
|
assert_hashmap_dirty()
|
||||||
mutable_state.hashmap.setdefault("another_key", 66)
|
assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
|
||||||
assert_hashmap_dirty()
|
assert_hashmap_dirty()
|
||||||
mutable_state.hashmap.pop("new_key")
|
assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
assert mutable_state.hashmap.pop("new_key") == 43
|
||||||
assert_hashmap_dirty()
|
assert_hashmap_dirty()
|
||||||
mutable_state.hashmap.popitem()
|
mutable_state.hashmap.popitem()
|
||||||
assert_hashmap_dirty()
|
assert_hashmap_dirty()
|
||||||
@ -1905,6 +1918,31 @@ def test_mutable_dict(mutable_state):
|
|||||||
mutable_state.hashmap["dict"]["dict"]["key"] = 43
|
mutable_state.hashmap["dict"]["dict"]["key"] = 43
|
||||||
assert_hashmap_dirty()
|
assert_hashmap_dirty()
|
||||||
|
|
||||||
|
# Test proxy returned from `setdefault` and `get`
|
||||||
|
mutable_value = mutable_state.hashmap.setdefault("setdefault_mutable_key", [])
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
assert mutable_value == []
|
||||||
|
assert isinstance(mutable_value, MutableProxy)
|
||||||
|
mutable_value.append("foo")
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
mutable_value_other_ref = mutable_state.hashmap.get("setdefault_mutable_key")
|
||||||
|
assert isinstance(mutable_value_other_ref, MutableProxy)
|
||||||
|
assert mutable_value is not mutable_value_other_ref
|
||||||
|
assert mutable_value == mutable_value_other_ref
|
||||||
|
assert not mutable_state.dirty_vars
|
||||||
|
mutable_value_other_ref.append("bar")
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
|
||||||
|
# `pop` should NOT return a proxy, because the returned value is no longer in the dict
|
||||||
|
mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
|
||||||
|
assert not isinstance(mutable_value_third_ref, MutableProxy)
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
mutable_value_third_ref.append("baz")
|
||||||
|
assert not mutable_state.dirty_vars
|
||||||
|
# Unfortunately previous refs still will mark the state dirty... nothing doing about that
|
||||||
|
assert mutable_value.pop()
|
||||||
|
assert_hashmap_dirty()
|
||||||
|
|
||||||
|
|
||||||
def test_mutable_set(mutable_state):
|
def test_mutable_set(mutable_state):
|
||||||
"""Test that mutable sets are tracked correctly.
|
"""Test that mutable sets are tracked correctly.
|
||||||
|
Loading…
Reference in New Issue
Block a user