From 41cb2d8cffa54ce396cddbe1525fbc55e143d344 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Jan 2025 15:49:28 -0800 Subject: [PATCH] [ENG-4083] Track internal changes in dataclass instances (#4558) * [ENG-4083] Track internal changes in dataclass instances Create a dynamic subclass of MutableProxy with `__dataclass_fields__` set according to the dataclass being wrapped. * support dataclasses.asdict on MutableProxy instances --- reflex/state.py | 67 +++++++++++++++++++++++++++++++++++++-- tests/units/test_state.py | 45 +++++++++++++++++++------- 2 files changed, 98 insertions(+), 14 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index c30a4038d..101155a97 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3649,6 +3649,9 @@ def get_state_manager() -> StateManager: class MutableProxy(wrapt.ObjectProxy): """A proxy for a mutable object that tracks changes.""" + # Hint for finding the base class of the proxy. + __base_proxy__ = "MutableProxy" + # Methods on wrapped objects which should mark the state as dirty. __mark_dirty_attrs__ = { "add", @@ -3691,6 +3694,39 @@ class MutableProxy(wrapt.ObjectProxy): BaseModelV1, ) + # Dynamically generated classes for tracking dataclass mutations. + __dataclass_proxies__: Dict[type, type] = {} + + def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy: + """Create a proxy instance for a mutable object that tracks changes. + + Args: + wrapped: The object to proxy. + *args: Other args passed to MutableProxy (ignored). + **kwargs: Other kwargs passed to MutableProxy (ignored). + + Returns: + The proxy instance. + """ + if dataclasses.is_dataclass(wrapped): + wrapped_cls = type(wrapped) + wrapper_cls_name = wrapped_cls.__name__ + cls.__name__ + # Find the associated class + if wrapper_cls_name not in cls.__dataclass_proxies__: + # Create a new class that has the __dataclass_fields__ defined + cls.__dataclass_proxies__[wrapper_cls_name] = type( + wrapper_cls_name, + (cls,), + { + dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues] + wrapped_cls, + dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues] + ), + }, + ) + cls = cls.__dataclass_proxies__[wrapper_cls_name] + return super().__new__(cls) + def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. @@ -3747,7 +3783,27 @@ class MutableProxy(wrapt.ObjectProxy): Returns: Whether the value is of a mutable type. """ - return isinstance(value, cls.__mutable_types__) + return isinstance(value, cls.__mutable_types__) or ( + dataclasses.is_dataclass(value) and not isinstance(value, Var) + ) + + @staticmethod + def _is_called_from_dataclasses_internal() -> bool: + """Check if the current function is called from dataclasses helper. + + Returns: + Whether the current function is called from dataclasses internal code. + """ + # Walk up the stack a bit to see if we are called from dataclasses + # internal code, for example `asdict` or `astuple`. + frame = inspect.currentframe() + for _ in range(5): + # Why not `inspect.stack()` -- this is much faster! + if not (frame := frame and frame.f_back): + break + if inspect.getfile(frame) == dataclasses.__file__: + return True + return False def _wrap_recursive(self, value: Any) -> Any: """Wrap a value recursively if it is mutable. @@ -3758,9 +3814,13 @@ class MutableProxy(wrapt.ObjectProxy): Returns: The wrapped value. """ + # When called from dataclasses internal code, return the unwrapped value + if self._is_called_from_dataclasses_internal(): + return value # Recursively wrap mutable types, but do not re-wrap MutableProxy instances. if self._is_mutable_type(value) and not isinstance(value, MutableProxy): - return type(self)( + base_cls = globals()[self.__base_proxy__] + return base_cls( wrapped=value, state=self._self_state, field_name=self._self_field_name, @@ -3968,6 +4028,9 @@ class ImmutableMutableProxy(MutableProxy): to modify the wrapped object when the StateProxy is immutable. """ + # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base. + __base_proxy__ = "ImmutableMutableProxy" + def _mark_dirty( self, wrapped=None, diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 39484752c..d6c48bd2b 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1936,6 +1936,14 @@ def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App: return mock_app_simple +@dataclasses.dataclass +class ModelDC: + """A dataclass.""" + + foo: str = "bar" + ls: list[dict] = dataclasses.field(default_factory=list) + + @pytest.mark.asyncio async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): """Test that the state proxy works. @@ -2038,6 +2046,7 @@ class BackgroundTaskState(BaseState): order: List[str] = [] dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]} + dc: ModelDC = ModelDC() def __init__(self, **kwargs): # noqa: D107 super().__init__(**kwargs) @@ -2063,10 +2072,18 @@ class BackgroundTaskState(BaseState): with pytest.raises(ImmutableStateError): self.order.append("bad idea") + with pytest.raises(ImmutableStateError): + # Cannot manipulate dataclass attributes. + self.dc.foo = "baz" + with pytest.raises(ImmutableStateError): # Even nested access to mutables raises an exception. self.dict_list["foo"].append(42) + with pytest.raises(ImmutableStateError): + # Cannot modify dataclass list attribute. + self.dc.ls.append({"foo": "bar"}) + with pytest.raises(ImmutableStateError): # Direct calling another handler that modifies state raises an exception. self.other() @@ -3582,13 +3599,6 @@ class ModelV2(BaseModelV2): foo: str = "bar" -@dataclasses.dataclass -class ModelDC: - """A dataclass.""" - - foo: str = "bar" - - class PydanticState(rx.State): """A state with pydantic BaseModel vars.""" @@ -3610,11 +3620,22 @@ def test_mutable_models(): assert state.dirty_vars == {"v2"} state.dirty_vars.clear() - # Not yet supported ENG-4083 - # assert isinstance(state.dc, MutableProxy) #noqa: ERA001 - # state.dc.foo = "baz" #noqa: ERA001 - # assert state.dirty_vars == {"dc"} #noqa: ERA001 - # state.dirty_vars.clear() #noqa: ERA001 + assert isinstance(state.dc, MutableProxy) + state.dc.foo = "baz" + assert state.dirty_vars == {"dc"} + state.dirty_vars.clear() + assert state.dirty_vars == set() + state.dc.ls.append({"hi": "reflex"}) + assert state.dirty_vars == {"dc"} + state.dirty_vars.clear() + assert state.dirty_vars == set() + assert dataclasses.asdict(state.dc) == {"foo": "baz", "ls": [{"hi": "reflex"}]} + assert dataclasses.astuple(state.dc) == ("baz", [{"hi": "reflex"}]) + # creating a new instance shouldn't mark the state dirty + assert dataclasses.replace(state.dc, foo="quuc") == ModelDC( + foo="quuc", ls=[{"hi": "reflex"}] + ) + assert state.dirty_vars == set() def test_get_value():