From b04e3a6ce9e9b319772c134234c77957f686fc7f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 31 May 2024 14:58:58 -0700 Subject: [PATCH] [REF-1356] Track changes applied to `Base` subclass via helper method. (#2242) --- reflex/state.py | 21 ++++++++++++++- tests/test_state.py | 65 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index 775ebd4e1..56b28f9e8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2867,6 +2867,11 @@ class MutableProxy(wrapt.ObjectProxy): ] ) + # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy. + __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set( + pydantic.BaseModel.__dict__ + ) + __mutable_types__ = (list, dict, set, Base) def __init__(self, wrapped: Any, state: BaseState, field_name: str): @@ -2916,7 +2921,10 @@ class MutableProxy(wrapt.ObjectProxy): Returns: The wrapped value. """ - if isinstance(value, self.__mutable_types__): + # Recursively wrap mutable types, but do not re-wrap MutableProxy instances. + if isinstance(value, self.__mutable_types__) and not isinstance( + value, MutableProxy + ): return type(self)( wrapped=value, state=self._self_state, @@ -2963,6 +2971,17 @@ class MutableProxy(wrapt.ObjectProxy): self._wrap_recursive_decorator, ) + if ( + isinstance(self.__wrapped__, Base) + and __name not in self.__never_wrap_base_attrs__ + and hasattr(value, "__func__") + ): + # Wrap methods called on Base subclasses, which might do _anything_ + return wrapt.FunctionWrapper( + functools.partial(value.__func__, self), + self._wrap_recursive_decorator, + ) + if isinstance(value, self.__mutable_types__) and __name not in ( "__wrapped__", "_self_state", diff --git a/tests/test_state.py b/tests/test_state.py index e6e16ddef..6fcd1a67f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2392,6 +2392,22 @@ class Custom1(Base): foo: str + def set_foo(self, val: str): + """Set the attribute foo. + + Args: + val: The value to set. + """ + self.foo = val + + def double_foo(self) -> str: + """Concantenate foo with foo. + + Returns: + foo + foo + """ + return self.foo + self.foo + class Custom2(Base): """A custom class with a Custom1 field.""" @@ -2399,6 +2415,14 @@ class Custom2(Base): c1: Optional[Custom1] = None c1r: Custom1 + def set_c1r_foo(self, val: str): + """Set the foo attribute of the c1 field. + + Args: + val: The value to set. + """ + self.c1r.set_foo(val) + class Custom3(Base): """A custom class with a Custom2 field.""" @@ -2436,6 +2460,47 @@ def test_state_union_optional(): assert types.is_union(UnionState.int_float._var_type) # type: ignore +def test_set_base_field_via_setter(): + """When calling a setter on a Base instance, also track changes.""" + + class BaseFieldSetterState(BaseState): + c1: Custom1 = Custom1(foo="") + c2: Custom2 = Custom2(c1r=Custom1(foo="")) + + bfss = BaseFieldSetterState() + assert "c1" not in bfss.dirty_vars + + # Non-mutating function, not dirty + bfss.c1.double_foo() + assert "c1" not in bfss.dirty_vars + + # Mutating function, dirty + bfss.c1.set_foo("bar") + assert "c1" in bfss.dirty_vars + bfss.dirty_vars.clear() + assert "c1" not in bfss.dirty_vars + + # Mutating function from Base, dirty + bfss.c1.set(foo="bar") + assert "c1" in bfss.dirty_vars + bfss.dirty_vars.clear() + assert "c1" not in bfss.dirty_vars + + # Assert identity of MutableProxy + mp = bfss.c1 + assert isinstance(mp, MutableProxy) + mp2 = mp.set() + assert mp is mp2 + mp3 = bfss.c1.set() + assert mp is not mp3 + # Since none of these set calls had values, the state should not be dirty + assert not bfss.dirty_vars + + # Chained Mutating function, dirty + bfss.c2.set_c1r_foo("baz") + assert "c2" in bfss.dirty_vars + + def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.