From 72d0e5f2307690bd5c4020baaed2ab1863adfb29 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 19 Dec 2024 02:12:29 -0800 Subject: [PATCH] [ENG-4083] Track internal changes in dataclass instances Create a dynamic subclass of MutableProxy with `__dataclass_fields__` set according to the dataclass being wrapped. --- reflex/state.py | 46 +++++++++++++++++++++++++++++++++++++-- tests/units/test_state.py | 33 ++++++++++++++++++---------- 2 files changed, 65 insertions(+), 14 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index e7e6bcf32..fb1a481f9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3645,6 +3645,9 @@ def get_state_manager() -> StateManager: class MutableProxy(wrapt.ObjectProxy): """A proxy for a mutable object that tracks changes.""" + # Hint for finding the base class of the proxy. + __base_proxy__ = "MutableProxy" + # Methods on wrapped objects which should mark the state as dirty. __mark_dirty_attrs__ = { "add", @@ -3687,6 +3690,39 @@ class MutableProxy(wrapt.ObjectProxy): BaseModelV1, ) + # Dynamically generated classes for tracking dataclass mutations. + __dataclass_proxies__: Dict[type, type] = {} + + def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy: + """Create a proxy instance for a mutable object that tracks changes. + + Args: + wrapped: The object to proxy. + *args: Other args passed to MutableProxy (ignored). + **kwargs: Other kwargs passed to MutableProxy (ignored). + + Returns: + The proxy instance. + """ + if dataclasses.is_dataclass(wrapped): + wrapped_cls = type(wrapped) + wrapper_cls_name = wrapped_cls.__name__ + cls.__name__ + # Find the associated class + if wrapper_cls_name not in cls.__dataclass_proxies__: + # Create a new class that has the __dataclass_fields__ defined + cls.__dataclass_proxies__[wrapper_cls_name] = type( + wrapper_cls_name, + (cls,), + { + dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues] + wrapped_cls, + dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues] + ), + }, + ) + cls = cls.__dataclass_proxies__[wrapper_cls_name] + return super().__new__(cls) + def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. @@ -3743,7 +3779,9 @@ class MutableProxy(wrapt.ObjectProxy): Returns: Whether the value is of a mutable type. """ - return isinstance(value, cls.__mutable_types__) + return isinstance(value, cls.__mutable_types__) or ( + dataclasses.is_dataclass(value) and not isinstance(value, Var) + ) def _wrap_recursive(self, value: Any) -> Any: """Wrap a value recursively if it is mutable. @@ -3756,7 +3794,8 @@ class MutableProxy(wrapt.ObjectProxy): """ # Recursively wrap mutable types, but do not re-wrap MutableProxy instances. if self._is_mutable_type(value) and not isinstance(value, MutableProxy): - return type(self)( + base_cls = globals()[self.__base_proxy__] + return base_cls( wrapped=value, state=self._self_state, field_name=self._self_field_name, @@ -3964,6 +4003,9 @@ class ImmutableMutableProxy(MutableProxy): to modify the wrapped object when the StateProxy is immutable. """ + # Ensure that recursively wrapped proxies use ImmutableMutableProxy as base. + __base_proxy__ = "ImmutableMutableProxy" + def _mark_dirty( self, wrapped=None, diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 912d72f4f..4b9b2d6cf 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1936,6 +1936,14 @@ def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App: return mock_app_simple +@dataclasses.dataclass +class ModelDC: + """A dataclass.""" + + foo: str = "bar" + ls: list[dict] = dataclasses.field(default_factory=list) + + @pytest.mark.asyncio async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): """Test that the state proxy works. @@ -2038,6 +2046,7 @@ class BackgroundTaskState(BaseState): order: List[str] = [] dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]} + dc: ModelDC = ModelDC() def __init__(self, **kwargs): # noqa: D107 super().__init__(**kwargs) @@ -2063,10 +2072,18 @@ class BackgroundTaskState(BaseState): with pytest.raises(ImmutableStateError): self.order.append("bad idea") + with pytest.raises(ImmutableStateError): + # Cannot manipulate dataclass attributes. + self.dc.foo = "baz" + with pytest.raises(ImmutableStateError): # Even nested access to mutables raises an exception. self.dict_list["foo"].append(42) + with pytest.raises(ImmutableStateError): + # Cannot modify dataclass list attribute. + self.dc.ls.append({"foo": "bar"}) + with pytest.raises(ImmutableStateError): # Direct calling another handler that modifies state raises an exception. self.other() @@ -3582,13 +3599,6 @@ class ModelV2(BaseModelV2): foo: str = "bar" -@dataclasses.dataclass -class ModelDC: - """A dataclass.""" - - foo: str = "bar" - - class PydanticState(rx.State): """A state with pydantic BaseModel vars.""" @@ -3610,11 +3620,10 @@ def test_mutable_models(): assert state.dirty_vars == {"v2"} state.dirty_vars.clear() - # Not yet supported ENG-4083 - # assert isinstance(state.dc, MutableProxy) #noqa: ERA001 - # state.dc.foo = "baz" #noqa: ERA001 - # assert state.dirty_vars == {"dc"} #noqa: ERA001 - # state.dirty_vars.clear() #noqa: ERA001 + assert isinstance(state.dc, MutableProxy) + state.dc.foo = "baz" + assert state.dirty_vars == {"dc"} + state.dirty_vars.clear() def test_get_value():