diff --git a/reflex/state.py b/reflex/state.py index 120a57881..e72da636a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1647,6 +1647,13 @@ class MutableProxy(wrapt.ObjectProxy): "update", ] ) + # Methods on wrapped objects might return mutable objects that should be tracked. + __wrap_mutable_attrs__ = set( + [ + "get", + "setdefault", + ] + ) __mutable_types__ = (list, dict, set, Base) @@ -1663,7 +1670,13 @@ class MutableProxy(wrapt.ObjectProxy): self._self_state = state self._self_field_name = field_name - def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None): + def _mark_dirty( + self, + wrapped=None, + instance=None, + args=tuple(), + kwargs=None, + ) -> Any: """Mark the state as dirty, then call a wrapped function. Intended for use with `FunctionWrapper` from the `wrapt` library. @@ -1673,11 +1686,47 @@ class MutableProxy(wrapt.ObjectProxy): instance: The instance of the wrapped function. args: The args for the wrapped function. kwargs: The kwargs for the wrapped function. + + Returns: + The result of the wrapped function. """ self._self_state.dirty_vars.add(self._self_field_name) self._self_state._mark_dirty() if wrapped is not None: - wrapped(*args, **(kwargs or {})) + return wrapped(*args, **(kwargs or {})) + + def _wrap_recursive(self, value: Any) -> Any: + """Wrap a value recursively if it is mutable. + + Args: + value: The value to wrap. + + Returns: + The wrapped value. + """ + if isinstance(value, self.__mutable_types__): + return type(self)( + wrapped=value, + state=self._self_state, + field_name=self._self_field_name, + ) + return value + + def _wrap_recursive_decorator(self, wrapped, instance, args, kwargs) -> Any: + """Wrap a function that returns a possibly mutable value. + + Intended for use with `FunctionWrapper` from the `wrapt` library. + + Args: + wrapped: The wrapped function. + instance: The instance of the wrapped function. + args: The args for the wrapped function. + kwargs: The kwargs for the wrapped function. + + Returns: + The result of the wrapped function (possibly wrapped in a MutableProxy). + """ + return self._wrap_recursive(wrapped(*args, **kwargs)) def __getattribute__(self, __name: str) -> Any: """Get the attribute on the proxied object and return a proxy if mutable. @@ -1690,24 +1739,26 @@ class MutableProxy(wrapt.ObjectProxy): """ value = super().__getattribute__(__name) - if callable(value) and __name in super().__getattribute__( - "__mark_dirty_attrs__" - ): - # Wrap special callables, like "append", which should mark state dirty. - return wrapt.FunctionWrapper( - value, - super().__getattribute__("_mark_dirty"), - ) + if callable(value): + if __name in super().__getattribute__("__mark_dirty_attrs__"): + # Wrap special callables, like "append", which should mark state dirty. + value = wrapt.FunctionWrapper( + value, + super().__getattribute__("_mark_dirty"), + ) + + if __name in super().__getattribute__("__wrap_mutable_attrs__"): + # Wrap methods that may return mutable objects tied to the state. + value = wrapt.FunctionWrapper( + value, + super().__getattribute__("_wrap_recursive_decorator"), + ) if isinstance( value, super().__getattribute__("__mutable_types__") ) and __name not in ("__wrapped__", "_self_state"): # Recursively wrap mutable attribute values retrieved through this proxy. - return type(self)( - wrapped=value, - state=self._self_state, - field_name=self._self_field_name, - ) + return self._wrap_recursive(value) return value @@ -1721,14 +1772,18 @@ class MutableProxy(wrapt.ObjectProxy): The item value. """ value = super().__getitem__(key) - if isinstance(value, self.__mutable_types__): + # Recursively wrap mutable items retrieved through this proxy. + return self._wrap_recursive(value) + + def __iter__(self) -> Any: + """Iterate over the proxied object and return a proxy if mutable. + + Yields: + Each item value (possibly wrapped in MutableProxy). + """ + for value in super().__iter__(): # Recursively wrap mutable items retrieved through this proxy. - return type(self)( - wrapped=value, - state=self._self_state, - field_name=self._self_field_name, - ) - return value + yield self._wrap_recursive(value) def __delattr__(self, name): """Delete the attribute on the proxied object and mark state dirty. diff --git a/tests/test_state.py b/tests/test_state.py index af327f37f..0b91b5ba7 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1858,6 +1858,15 @@ def test_mutable_list(mutable_state): assert_array_dirty() assert isinstance(mutable_state.array[0], MutableProxy) + # Test proxy returned from __iter__ + mutable_state.array = [{}] + assert_array_dirty() + assert isinstance(mutable_state.array[0], MutableProxy) + for item in mutable_state.array: + assert isinstance(item, MutableProxy) + item["foo"] = "bar" + assert_array_dirty() + def test_mutable_dict(mutable_state): """Test that mutable dicts are tracked correctly. @@ -1875,9 +1884,13 @@ def test_mutable_dict(mutable_state): # Test all dict operations mutable_state.hashmap.update({"new_key": 43}) assert_hashmap_dirty() - mutable_state.hashmap.setdefault("another_key", 66) + assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value" assert_hashmap_dirty() - mutable_state.hashmap.pop("new_key") + assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67 + assert_hashmap_dirty() + assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67 + assert_hashmap_dirty() + assert mutable_state.hashmap.pop("new_key") == 43 assert_hashmap_dirty() mutable_state.hashmap.popitem() assert_hashmap_dirty() @@ -1905,6 +1918,31 @@ def test_mutable_dict(mutable_state): mutable_state.hashmap["dict"]["dict"]["key"] = 43 assert_hashmap_dirty() + # Test proxy returned from `setdefault` and `get` + mutable_value = mutable_state.hashmap.setdefault("setdefault_mutable_key", []) + assert_hashmap_dirty() + assert mutable_value == [] + assert isinstance(mutable_value, MutableProxy) + mutable_value.append("foo") + assert_hashmap_dirty() + mutable_value_other_ref = mutable_state.hashmap.get("setdefault_mutable_key") + assert isinstance(mutable_value_other_ref, MutableProxy) + assert mutable_value is not mutable_value_other_ref + assert mutable_value == mutable_value_other_ref + assert not mutable_state.dirty_vars + mutable_value_other_ref.append("bar") + assert_hashmap_dirty() + + # `pop` should NOT return a proxy, because the returned value is no longer in the dict + mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key") + assert not isinstance(mutable_value_third_ref, MutableProxy) + assert_hashmap_dirty() + mutable_value_third_ref.append("baz") + assert not mutable_state.dirty_vars + # Unfortunately previous refs still will mark the state dirty... nothing doing about that + assert mutable_value.pop() + assert_hashmap_dirty() + def test_mutable_set(mutable_state): """Test that mutable sets are tracked correctly.