[ENG-4083] Track internal changes in dataclass instances (#4558)

* [ENG-4083] Track internal changes in dataclass instances

Create a dynamic subclass of MutableProxy with `__dataclass_fields__` set
according to the dataclass being wrapped.

* support dataclasses.asdict on MutableProxy instances
This commit is contained in:
Masen Furer 2025-01-03 15:49:28 -08:00 committed by GitHub
parent 4b89b8260b
commit 41cb2d8cff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 98 additions and 14 deletions

View File

@ -3649,6 +3649,9 @@ def get_state_manager() -> StateManager:
class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes."""
# Hint for finding the base class of the proxy.
__base_proxy__ = "MutableProxy"
# Methods on wrapped objects which should mark the state as dirty.
__mark_dirty_attrs__ = {
"add",
@ -3691,6 +3694,39 @@ class MutableProxy(wrapt.ObjectProxy):
BaseModelV1,
)
# Dynamically generated classes for tracking dataclass mutations.
__dataclass_proxies__: Dict[type, type] = {}
def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
"""Create a proxy instance for a mutable object that tracks changes.
Args:
wrapped: The object to proxy.
*args: Other args passed to MutableProxy (ignored).
**kwargs: Other kwargs passed to MutableProxy (ignored).
Returns:
The proxy instance.
"""
if dataclasses.is_dataclass(wrapped):
wrapped_cls = type(wrapped)
wrapper_cls_name = wrapped_cls.__name__ + cls.__name__
# Find the associated class
if wrapper_cls_name not in cls.__dataclass_proxies__:
# Create a new class that has the __dataclass_fields__ defined
cls.__dataclass_proxies__[wrapper_cls_name] = type(
wrapper_cls_name,
(cls,),
{
dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
wrapped_cls,
dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
),
},
)
cls = cls.__dataclass_proxies__[wrapper_cls_name]
return super().__new__(cls)
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
@ -3747,7 +3783,27 @@ class MutableProxy(wrapt.ObjectProxy):
Returns:
Whether the value is of a mutable type.
"""
return isinstance(value, cls.__mutable_types__)
return isinstance(value, cls.__mutable_types__) or (
dataclasses.is_dataclass(value) and not isinstance(value, Var)
)
@staticmethod
def _is_called_from_dataclasses_internal() -> bool:
"""Check if the current function is called from dataclasses helper.
Returns:
Whether the current function is called from dataclasses internal code.
"""
# Walk up the stack a bit to see if we are called from dataclasses
# internal code, for example `asdict` or `astuple`.
frame = inspect.currentframe()
for _ in range(5):
# Why not `inspect.stack()` -- this is much faster!
if not (frame := frame and frame.f_back):
break
if inspect.getfile(frame) == dataclasses.__file__:
return True
return False
def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable.
@ -3758,9 +3814,13 @@ class MutableProxy(wrapt.ObjectProxy):
Returns:
The wrapped value.
"""
# When called from dataclasses internal code, return the unwrapped value
if self._is_called_from_dataclasses_internal():
return value
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
return type(self)(
base_cls = globals()[self.__base_proxy__]
return base_cls(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
@ -3968,6 +4028,9 @@ class ImmutableMutableProxy(MutableProxy):
to modify the wrapped object when the StateProxy is immutable.
"""
# Ensure that recursively wrapped proxies use ImmutableMutableProxy as base.
__base_proxy__ = "ImmutableMutableProxy"
def _mark_dirty(
self,
wrapped=None,

View File

@ -1936,6 +1936,14 @@ def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
return mock_app_simple
@dataclasses.dataclass
class ModelDC:
"""A dataclass."""
foo: str = "bar"
ls: list[dict] = dataclasses.field(default_factory=list)
@pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
"""Test that the state proxy works.
@ -2038,6 +2046,7 @@ class BackgroundTaskState(BaseState):
order: List[str] = []
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
dc: ModelDC = ModelDC()
def __init__(self, **kwargs): # noqa: D107
super().__init__(**kwargs)
@ -2063,10 +2072,18 @@ class BackgroundTaskState(BaseState):
with pytest.raises(ImmutableStateError):
self.order.append("bad idea")
with pytest.raises(ImmutableStateError):
# Cannot manipulate dataclass attributes.
self.dc.foo = "baz"
with pytest.raises(ImmutableStateError):
# Even nested access to mutables raises an exception.
self.dict_list["foo"].append(42)
with pytest.raises(ImmutableStateError):
# Cannot modify dataclass list attribute.
self.dc.ls.append({"foo": "bar"})
with pytest.raises(ImmutableStateError):
# Direct calling another handler that modifies state raises an exception.
self.other()
@ -3582,13 +3599,6 @@ class ModelV2(BaseModelV2):
foo: str = "bar"
@dataclasses.dataclass
class ModelDC:
"""A dataclass."""
foo: str = "bar"
class PydanticState(rx.State):
"""A state with pydantic BaseModel vars."""
@ -3610,11 +3620,22 @@ def test_mutable_models():
assert state.dirty_vars == {"v2"}
state.dirty_vars.clear()
# Not yet supported ENG-4083
# assert isinstance(state.dc, MutableProxy) #noqa: ERA001
# state.dc.foo = "baz" #noqa: ERA001
# assert state.dirty_vars == {"dc"} #noqa: ERA001
# state.dirty_vars.clear() #noqa: ERA001
assert isinstance(state.dc, MutableProxy)
state.dc.foo = "baz"
assert state.dirty_vars == {"dc"}
state.dirty_vars.clear()
assert state.dirty_vars == set()
state.dc.ls.append({"hi": "reflex"})
assert state.dirty_vars == {"dc"}
state.dirty_vars.clear()
assert state.dirty_vars == set()
assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]}
assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}])
# creating a new instance shouldn't mark the state dirty
assert dataclasses.replace(state.dc, foo="quuc") == ModelDC(
foo="quuc", ls=[{"hi": "reflex"}]
)
assert state.dirty_vars == set()
def test_get_value():