[REF-1356] Track changes applied to Base
subclass via helper method. (#2242)
This commit is contained in:
parent
5995b32f5f
commit
b04e3a6ce9
@ -2867,6 +2867,11 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
|
||||||
|
__never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
|
||||||
|
pydantic.BaseModel.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
__mutable_types__ = (list, dict, set, Base)
|
__mutable_types__ = (list, dict, set, Base)
|
||||||
|
|
||||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||||
@ -2916,7 +2921,10 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
Returns:
|
Returns:
|
||||||
The wrapped value.
|
The wrapped value.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, self.__mutable_types__):
|
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
||||||
|
if isinstance(value, self.__mutable_types__) and not isinstance(
|
||||||
|
value, MutableProxy
|
||||||
|
):
|
||||||
return type(self)(
|
return type(self)(
|
||||||
wrapped=value,
|
wrapped=value,
|
||||||
state=self._self_state,
|
state=self._self_state,
|
||||||
@ -2963,6 +2971,17 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
self._wrap_recursive_decorator,
|
self._wrap_recursive_decorator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(self.__wrapped__, Base)
|
||||||
|
and __name not in self.__never_wrap_base_attrs__
|
||||||
|
and hasattr(value, "__func__")
|
||||||
|
):
|
||||||
|
# Wrap methods called on Base subclasses, which might do _anything_
|
||||||
|
return wrapt.FunctionWrapper(
|
||||||
|
functools.partial(value.__func__, self),
|
||||||
|
self._wrap_recursive_decorator,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(value, self.__mutable_types__) and __name not in (
|
if isinstance(value, self.__mutable_types__) and __name not in (
|
||||||
"__wrapped__",
|
"__wrapped__",
|
||||||
"_self_state",
|
"_self_state",
|
||||||
|
@ -2392,6 +2392,22 @@ class Custom1(Base):
|
|||||||
|
|
||||||
foo: str
|
foo: str
|
||||||
|
|
||||||
|
def set_foo(self, val: str):
|
||||||
|
"""Set the attribute foo.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
val: The value to set.
|
||||||
|
"""
|
||||||
|
self.foo = val
|
||||||
|
|
||||||
|
def double_foo(self) -> str:
|
||||||
|
"""Concantenate foo with foo.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
foo + foo
|
||||||
|
"""
|
||||||
|
return self.foo + self.foo
|
||||||
|
|
||||||
|
|
||||||
class Custom2(Base):
|
class Custom2(Base):
|
||||||
"""A custom class with a Custom1 field."""
|
"""A custom class with a Custom1 field."""
|
||||||
@ -2399,6 +2415,14 @@ class Custom2(Base):
|
|||||||
c1: Optional[Custom1] = None
|
c1: Optional[Custom1] = None
|
||||||
c1r: Custom1
|
c1r: Custom1
|
||||||
|
|
||||||
|
def set_c1r_foo(self, val: str):
|
||||||
|
"""Set the foo attribute of the c1 field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
val: The value to set.
|
||||||
|
"""
|
||||||
|
self.c1r.set_foo(val)
|
||||||
|
|
||||||
|
|
||||||
class Custom3(Base):
|
class Custom3(Base):
|
||||||
"""A custom class with a Custom2 field."""
|
"""A custom class with a Custom2 field."""
|
||||||
@ -2436,6 +2460,47 @@ def test_state_union_optional():
|
|||||||
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_base_field_via_setter():
|
||||||
|
"""When calling a setter on a Base instance, also track changes."""
|
||||||
|
|
||||||
|
class BaseFieldSetterState(BaseState):
|
||||||
|
c1: Custom1 = Custom1(foo="")
|
||||||
|
c2: Custom2 = Custom2(c1r=Custom1(foo=""))
|
||||||
|
|
||||||
|
bfss = BaseFieldSetterState()
|
||||||
|
assert "c1" not in bfss.dirty_vars
|
||||||
|
|
||||||
|
# Non-mutating function, not dirty
|
||||||
|
bfss.c1.double_foo()
|
||||||
|
assert "c1" not in bfss.dirty_vars
|
||||||
|
|
||||||
|
# Mutating function, dirty
|
||||||
|
bfss.c1.set_foo("bar")
|
||||||
|
assert "c1" in bfss.dirty_vars
|
||||||
|
bfss.dirty_vars.clear()
|
||||||
|
assert "c1" not in bfss.dirty_vars
|
||||||
|
|
||||||
|
# Mutating function from Base, dirty
|
||||||
|
bfss.c1.set(foo="bar")
|
||||||
|
assert "c1" in bfss.dirty_vars
|
||||||
|
bfss.dirty_vars.clear()
|
||||||
|
assert "c1" not in bfss.dirty_vars
|
||||||
|
|
||||||
|
# Assert identity of MutableProxy
|
||||||
|
mp = bfss.c1
|
||||||
|
assert isinstance(mp, MutableProxy)
|
||||||
|
mp2 = mp.set()
|
||||||
|
assert mp is mp2
|
||||||
|
mp3 = bfss.c1.set()
|
||||||
|
assert mp is not mp3
|
||||||
|
# Since none of these set calls had values, the state should not be dirty
|
||||||
|
assert not bfss.dirty_vars
|
||||||
|
|
||||||
|
# Chained Mutating function, dirty
|
||||||
|
bfss.c2.set_c1r_foo("baz")
|
||||||
|
assert "c2" in bfss.dirty_vars
|
||||||
|
|
||||||
|
|
||||||
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
|
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
|
||||||
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user