Add MutableProxy._is_mutable_value
to avoid duplicate logic
This commit is contained in:
parent
6595fcf2ca
commit
e2c6f6983b
@ -1244,7 +1244,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if parent_state is not None:
|
if parent_state is not None:
|
||||||
return getattr(parent_state, name)
|
return getattr(parent_state, name)
|
||||||
|
|
||||||
if isinstance(value, MutableProxy.__mutable_types__) and (
|
if MutableProxy._is_mutable_type(value) and (
|
||||||
name in super().__getattribute__("base_vars") or name in backend_vars
|
name in super().__getattribute__("base_vars") or name in backend_vars
|
||||||
):
|
):
|
||||||
# track changes in mutable containers (list, dict, set, etc)
|
# track changes in mutable containers (list, dict, set, etc)
|
||||||
@ -3522,6 +3522,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
pydantic.BaseModel.__dict__
|
pydantic.BaseModel.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# These types will be wrapped in MutableProxy
|
||||||
__mutable_types__ = (
|
__mutable_types__ = (
|
||||||
list,
|
list,
|
||||||
dict,
|
dict,
|
||||||
@ -3570,6 +3571,18 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
if wrapped is not None:
|
if wrapped is not None:
|
||||||
return wrapped(*args, **(kwargs or {}))
|
return wrapped(*args, **(kwargs or {}))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _is_mutable_type(cls, value: Any) -> bool:
|
||||||
|
"""Check if a value is of a mutable type and should be wrapped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the value is of a mutable type.
|
||||||
|
"""
|
||||||
|
return isinstance(value, cls.__mutable_types__)
|
||||||
|
|
||||||
def _wrap_recursive(self, value: Any) -> Any:
|
def _wrap_recursive(self, value: Any) -> Any:
|
||||||
"""Wrap a value recursively if it is mutable.
|
"""Wrap a value recursively if it is mutable.
|
||||||
|
|
||||||
@ -3580,9 +3593,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
The wrapped value.
|
The wrapped value.
|
||||||
"""
|
"""
|
||||||
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
||||||
if isinstance(value, self.__mutable_types__) and not isinstance(
|
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
||||||
value, MutableProxy
|
|
||||||
):
|
|
||||||
return type(self)(
|
return type(self)(
|
||||||
wrapped=value,
|
wrapped=value,
|
||||||
state=self._self_state,
|
state=self._self_state,
|
||||||
@ -3640,7 +3651,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
self._wrap_recursive_decorator,
|
self._wrap_recursive_decorator,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(value, self.__mutable_types__) and __name not in (
|
if self._is_mutable_type(value) and __name not in (
|
||||||
"__wrapped__",
|
"__wrapped__",
|
||||||
"_self_state",
|
"_self_state",
|
||||||
):
|
):
|
||||||
|
@ -283,13 +283,14 @@ def serialize_base_model_v1(model: BaseModelV1) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
if BaseModelV2 is not BaseModelV1:
|
if BaseModelV2 is not BaseModelV1:
|
||||||
|
|
||||||
@serializer(to=dict)
|
@serializer(to=dict)
|
||||||
def serialize_base_model_v2(model: BaseModelV2) -> dict:
|
def serialize_base_model_v2(model: BaseModelV2) -> dict:
|
||||||
"""Serialize a pydantic v2 BaseModel instance.
|
"""Serialize a pydantic v2 BaseModel instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The BaseModel to serialize.
|
model: The BaseModel to serialize.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The serialized BaseModel.
|
The serialized BaseModel.
|
||||||
"""
|
"""
|
||||||
|
@ -3427,20 +3427,36 @@ class ModelV2(BaseModelV2):
|
|||||||
foo: str = "bar"
|
foo: str = "bar"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ModelDC:
|
||||||
|
"""A dataclass."""
|
||||||
|
|
||||||
|
foo: str = "bar"
|
||||||
|
|
||||||
|
|
||||||
class PydanticState(rx.State):
|
class PydanticState(rx.State):
|
||||||
"""A state with pydantic BaseModel vars."""
|
"""A state with pydantic BaseModel vars."""
|
||||||
|
|
||||||
v1: ModelV1 = ModelV1()
|
v1: ModelV1 = ModelV1()
|
||||||
v2: ModelV2 = ModelV2()
|
v2: ModelV2 = ModelV2()
|
||||||
|
dc: ModelDC = ModelDC()
|
||||||
|
|
||||||
|
|
||||||
def test_pydantic_base_models():
|
def test_mutable_models():
|
||||||
"""Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking."""
|
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
|
||||||
state = PydanticState()
|
state = PydanticState()
|
||||||
assert isinstance(state.v1, MutableProxy)
|
assert isinstance(state.v1, MutableProxy)
|
||||||
state.v1.foo = "baz"
|
state.v1.foo = "baz"
|
||||||
assert "v1" in state.dirty_vars
|
assert state.dirty_vars == {"v1"}
|
||||||
|
state.dirty_vars.clear()
|
||||||
|
|
||||||
assert isinstance(state.v2, MutableProxy)
|
assert isinstance(state.v2, MutableProxy)
|
||||||
state.v2.foo = "baz"
|
state.v2.foo = "baz"
|
||||||
assert "v2" in state.dirty_vars
|
assert state.dirty_vars == {"v2"}
|
||||||
|
state.dirty_vars.clear()
|
||||||
|
|
||||||
|
# Not yet supported ENG-4083
|
||||||
|
# assert isinstance(state.dc, MutableProxy)
|
||||||
|
# state.dc.foo = "baz"
|
||||||
|
# assert state.dirty_vars == {"dc"}
|
||||||
|
# state.dirty_vars.clear()
|
||||||
|
Loading…
Reference in New Issue
Block a user