Add MutableProxy._is_mutable_value to avoid duplicate logic

This commit is contained in:
Masen Furer 2024-11-08 16:42:57 -08:00
parent 6595fcf2ca
commit e2c6f6983b
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
3 changed files with 39 additions and 11 deletions

View File

@ -1244,7 +1244,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if parent_state is not None:
return getattr(parent_state, name)
if isinstance(value, MutableProxy.__mutable_types__) and (
if MutableProxy._is_mutable_type(value) and (
name in super().__getattribute__("base_vars") or name in backend_vars
):
# track changes in mutable containers (list, dict, set, etc)
@ -3522,6 +3522,7 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)
# These types will be wrapped in MutableProxy
__mutable_types__ = (
list,
dict,
@ -3570,6 +3571,18 @@ class MutableProxy(wrapt.ObjectProxy):
if wrapped is not None:
return wrapped(*args, **(kwargs or {}))
@classmethod
def _is_mutable_type(cls, value: Any) -> bool:
"""Check if a value is of a mutable type and should be wrapped.
Args:
value: The value to check.
Returns:
Whether the value is of a mutable type.
"""
return isinstance(value, cls.__mutable_types__)
def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable.
@ -3580,9 +3593,7 @@ class MutableProxy(wrapt.ObjectProxy):
The wrapped value.
"""
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance(
value, MutableProxy
):
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
return type(self)(
wrapped=value,
state=self._self_state,
@ -3640,7 +3651,7 @@ class MutableProxy(wrapt.ObjectProxy):
self._wrap_recursive_decorator,
)
if isinstance(value, self.__mutable_types__) and __name not in (
if self._is_mutable_type(value) and __name not in (
"__wrapped__",
"_self_state",
):

View File

@ -283,6 +283,7 @@ def serialize_base_model_v1(model: BaseModelV1) -> dict:
if BaseModelV2 is not BaseModelV1:
@serializer(to=dict)
def serialize_base_model_v2(model: BaseModelV2) -> dict:
"""Serialize a pydantic v2 BaseModel instance.

View File

@ -3427,20 +3427,36 @@ class ModelV2(BaseModelV2):
foo: str = "bar"
@dataclasses.dataclass
class ModelDC:
"""A dataclass."""
foo: str = "bar"
class PydanticState(rx.State):
"""A state with pydantic BaseModel vars."""
v1: ModelV1 = ModelV1()
v2: ModelV2 = ModelV2()
dc: ModelDC = ModelDC()
def test_pydantic_base_models():
"""Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking."""
def test_mutable_models():
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
state = PydanticState()
assert isinstance(state.v1, MutableProxy)
state.v1.foo = "baz"
assert "v1" in state.dirty_vars
assert state.dirty_vars == {"v1"}
state.dirty_vars.clear()
assert isinstance(state.v2, MutableProxy)
state.v2.foo = "baz"
assert "v2" in state.dirty_vars
assert state.dirty_vars == {"v2"}
state.dirty_vars.clear()
# Not yet supported ENG-4083
# assert isinstance(state.dc, MutableProxy)
# state.dc.foo = "baz"
# assert state.dirty_vars == {"dc"}
# state.dirty_vars.clear()