[REF-1356] Track changes applied to Base subclass via helper method. (#2242)

This commit is contained in:
Masen Furer 2024-05-31 14:58:58 -07:00 committed by GitHub
parent 5995b32f5f
commit b04e3a6ce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 1 deletions

View File

@ -2867,6 +2867,11 @@ class MutableProxy(wrapt.ObjectProxy):
]
)
# These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
__never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
pydantic.BaseModel.__dict__
)
__mutable_types__ = (list, dict, set, Base)
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
@ -2916,7 +2921,10 @@ class MutableProxy(wrapt.ObjectProxy):
Returns:
The wrapped value.
"""
if isinstance(value, self.__mutable_types__):
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance(
value, MutableProxy
):
return type(self)(
wrapped=value,
state=self._self_state,
@ -2963,6 +2971,17 @@ class MutableProxy(wrapt.ObjectProxy):
self._wrap_recursive_decorator,
)
if (
isinstance(self.__wrapped__, Base)
and __name not in self.__never_wrap_base_attrs__
and hasattr(value, "__func__")
):
# Wrap methods called on Base subclasses, which might do _anything_
return wrapt.FunctionWrapper(
functools.partial(value.__func__, self),
self._wrap_recursive_decorator,
)
if isinstance(value, self.__mutable_types__) and __name not in (
"__wrapped__",
"_self_state",

View File

@ -2392,6 +2392,22 @@ class Custom1(Base):
foo: str
def set_foo(self, val: str):
"""Set the attribute foo.
Args:
val: The value to set.
"""
self.foo = val
def double_foo(self) -> str:
"""Concantenate foo with foo.
Returns:
foo + foo
"""
return self.foo + self.foo
class Custom2(Base):
"""A custom class with a Custom1 field."""
@ -2399,6 +2415,14 @@ class Custom2(Base):
c1: Optional[Custom1] = None
c1r: Custom1
def set_c1r_foo(self, val: str):
"""Set the foo attribute of the c1 field.
Args:
val: The value to set.
"""
self.c1r.set_foo(val)
class Custom3(Base):
"""A custom class with a Custom2 field."""
@ -2436,6 +2460,47 @@ def test_state_union_optional():
assert types.is_union(UnionState.int_float._var_type) # type: ignore
def test_set_base_field_via_setter():
"""When calling a setter on a Base instance, also track changes."""
class BaseFieldSetterState(BaseState):
c1: Custom1 = Custom1(foo="")
c2: Custom2 = Custom2(c1r=Custom1(foo=""))
bfss = BaseFieldSetterState()
assert "c1" not in bfss.dirty_vars
# Non-mutating function, not dirty
bfss.c1.double_foo()
assert "c1" not in bfss.dirty_vars
# Mutating function, dirty
bfss.c1.set_foo("bar")
assert "c1" in bfss.dirty_vars
bfss.dirty_vars.clear()
assert "c1" not in bfss.dirty_vars
# Mutating function from Base, dirty
bfss.c1.set(foo="bar")
assert "c1" in bfss.dirty_vars
bfss.dirty_vars.clear()
assert "c1" not in bfss.dirty_vars
# Assert identity of MutableProxy
mp = bfss.c1
assert isinstance(mp, MutableProxy)
mp2 = mp.set()
assert mp is mp2
mp3 = bfss.c1.set()
assert mp is not mp3
# Since none of these set calls had values, the state should not be dirty
assert not bfss.dirty_vars
# Chained Mutating function, dirty
bfss.c2.set_c1r_foo("baz")
assert "c2" in bfss.dirty_vars
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.