support dataclasses.asdict on MutableProxy instances

This commit is contained in:
Masen Furer 2024-12-20 16:04:47 -08:00
parent 72d0e5f230
commit d2eb751ab5
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 33 additions and 0 deletions

View File

@ -3783,6 +3783,24 @@ class MutableProxy(wrapt.ObjectProxy):
dataclasses.is_dataclass(value) and not isinstance(value, Var) dataclasses.is_dataclass(value) and not isinstance(value, Var)
) )
@staticmethod
def _is_called_from_dataclasses_internal() -> bool:
"""Check if the current function is called from dataclasses helper.
Returns:
Whether the current function is called from dataclasses internal code.
"""
# Walk up the stack a bit to see if we are called from dataclasses
# internal code, for example `asdict` or `astuple`.
frame = inspect.currentframe()
for _ in range(5):
# Why not `inspect.stack()` -- this is much faster!
if not (frame := frame and frame.f_back):
break
if inspect.getfile(frame) == dataclasses.__file__:
return True
return False
def _wrap_recursive(self, value: Any) -> Any: def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable. """Wrap a value recursively if it is mutable.
@ -3792,6 +3810,9 @@ class MutableProxy(wrapt.ObjectProxy):
Returns: Returns:
The wrapped value. The wrapped value.
""" """
# When called from dataclasses internal code, return the unwrapped value
if self._is_called_from_dataclasses_internal():
return value
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances. # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if self._is_mutable_type(value) and not isinstance(value, MutableProxy): if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
base_cls = globals()[self.__base_proxy__] base_cls = globals()[self.__base_proxy__]

View File

@ -3624,6 +3624,18 @@ def test_mutable_models():
state.dc.foo = "baz" state.dc.foo = "baz"
assert state.dirty_vars == {"dc"} assert state.dirty_vars == {"dc"}
state.dirty_vars.clear() state.dirty_vars.clear()
assert state.dirty_vars == set()
state.dc.ls.append({"hi": "reflex"})
assert state.dirty_vars == {"dc"}
state.dirty_vars.clear()
assert state.dirty_vars == set()
assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]}
assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}])
# creating a new instance shouldn't mark the state dirty
assert dataclasses.replace(state.dc, foo="quuc") == ModelDC(
foo="quuc", ls=[{"hi": "reflex"}]
)
assert state.dirty_vars == set()
def test_get_value(): def test_get_value():