Add MutableProxy._is_mutable_value to avoid duplicate logic

This commit is contained in:
Masen Furer 2024-11-08 16:42:57 -08:00
parent 6595fcf2ca
commit e2c6f6983b
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
3 changed files with 39 additions and 11 deletions

View File

@ -1244,7 +1244,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if parent_state is not None: if parent_state is not None:
return getattr(parent_state, name) return getattr(parent_state, name)
if isinstance(value, MutableProxy.__mutable_types__) and ( if MutableProxy._is_mutable_type(value) and (
name in super().__getattribute__("base_vars") or name in backend_vars name in super().__getattribute__("base_vars") or name in backend_vars
): ):
# track changes in mutable containers (list, dict, set, etc) # track changes in mutable containers (list, dict, set, etc)
@ -3522,6 +3522,7 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__ pydantic.BaseModel.__dict__
) )
# These types will be wrapped in MutableProxy
__mutable_types__ = ( __mutable_types__ = (
list, list,
dict, dict,
@ -3570,6 +3571,18 @@ class MutableProxy(wrapt.ObjectProxy):
if wrapped is not None: if wrapped is not None:
return wrapped(*args, **(kwargs or {})) return wrapped(*args, **(kwargs or {}))
@classmethod
def _is_mutable_type(cls, value: Any) -> bool:
"""Check if a value is of a mutable type and should be wrapped.
Args:
value: The value to check.
Returns:
Whether the value is of a mutable type.
"""
return isinstance(value, cls.__mutable_types__)
def _wrap_recursive(self, value: Any) -> Any: def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable. """Wrap a value recursively if it is mutable.
@ -3580,9 +3593,7 @@ class MutableProxy(wrapt.ObjectProxy):
The wrapped value. The wrapped value.
""" """
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances. # Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance( if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
value, MutableProxy
):
return type(self)( return type(self)(
wrapped=value, wrapped=value,
state=self._self_state, state=self._self_state,
@ -3640,7 +3651,7 @@ class MutableProxy(wrapt.ObjectProxy):
self._wrap_recursive_decorator, self._wrap_recursive_decorator,
) )
if isinstance(value, self.__mutable_types__) and __name not in ( if self._is_mutable_type(value) and __name not in (
"__wrapped__", "__wrapped__",
"_self_state", "_self_state",
): ):

View File

@ -283,13 +283,14 @@ def serialize_base_model_v1(model: BaseModelV1) -> dict:
if BaseModelV2 is not BaseModelV1: if BaseModelV2 is not BaseModelV1:
@serializer(to=dict) @serializer(to=dict)
def serialize_base_model_v2(model: BaseModelV2) -> dict: def serialize_base_model_v2(model: BaseModelV2) -> dict:
"""Serialize a pydantic v2 BaseModel instance. """Serialize a pydantic v2 BaseModel instance.
Args: Args:
model: The BaseModel to serialize. model: The BaseModel to serialize.
Returns: Returns:
The serialized BaseModel. The serialized BaseModel.
""" """

View File

@ -3427,20 +3427,36 @@ class ModelV2(BaseModelV2):
foo: str = "bar" foo: str = "bar"
@dataclasses.dataclass
class ModelDC:
"""A dataclass."""
foo: str = "bar"
class PydanticState(rx.State): class PydanticState(rx.State):
"""A state with pydantic BaseModel vars.""" """A state with pydantic BaseModel vars."""
v1: ModelV1 = ModelV1() v1: ModelV1 = ModelV1()
v2: ModelV2 = ModelV2() v2: ModelV2 = ModelV2()
dc: ModelDC = ModelDC()
def test_pydantic_base_models(): def test_mutable_models():
"""Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking.""" """Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
state = PydanticState() state = PydanticState()
assert isinstance(state.v1, MutableProxy) assert isinstance(state.v1, MutableProxy)
state.v1.foo = "baz" state.v1.foo = "baz"
assert "v1" in state.dirty_vars assert state.dirty_vars == {"v1"}
state.dirty_vars.clear()
assert isinstance(state.v2, MutableProxy) assert isinstance(state.v2, MutableProxy)
state.v2.foo = "baz" state.v2.foo = "baz"
assert "v2" in state.dirty_vars assert state.dirty_vars == {"v2"}
state.dirty_vars.clear()
# Not yet supported ENG-4083
# assert isinstance(state.dc, MutableProxy)
# state.dc.foo = "baz"
# assert state.dirty_vars == {"dc"}
# state.dirty_vars.clear()