[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var (#4338)
* [ENG-3953] Support pydantic BaseModel (v1 and v2) as state var Provide serializers and mutable proxy tracking for pydantic models directly. * conditionally define v2 serializer Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com> * Add `MutableProxy._is_mutable_value` to avoid duplicate logic * Conditionally import BaseModel to handle older pydantic v1 versions * pre-commit fu --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
5702a18502
commit
a6b324bd3e
@ -62,6 +62,13 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
import pydantic
|
||||
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
except ModuleNotFoundError:
|
||||
BaseModelV1 = BaseModelV2
|
||||
|
||||
import wrapt
|
||||
from redis.asyncio import Redis
|
||||
from redis.exceptions import ResponseError
|
||||
@ -1250,7 +1257,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if parent_state is not None:
|
||||
return getattr(parent_state, name)
|
||||
|
||||
if isinstance(value, MutableProxy.__mutable_types__) and (
|
||||
if MutableProxy._is_mutable_type(value) and (
|
||||
name in super().__getattribute__("base_vars") or name in backend_vars
|
||||
):
|
||||
# track changes in mutable containers (list, dict, set, etc)
|
||||
@ -3558,7 +3565,16 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
pydantic.BaseModel.__dict__
|
||||
)
|
||||
|
||||
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
|
||||
# These types will be wrapped in MutableProxy
|
||||
__mutable_types__ = (
|
||||
list,
|
||||
dict,
|
||||
set,
|
||||
Base,
|
||||
DeclarativeBase,
|
||||
BaseModelV2,
|
||||
BaseModelV1,
|
||||
)
|
||||
|
||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||
"""Create a proxy for a mutable object that tracks changes.
|
||||
@ -3598,6 +3614,18 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
if wrapped is not None:
|
||||
return wrapped(*args, **(kwargs or {}))
|
||||
|
||||
@classmethod
|
||||
def _is_mutable_type(cls, value: Any) -> bool:
|
||||
"""Check if a value is of a mutable type and should be wrapped.
|
||||
|
||||
Args:
|
||||
value: The value to check.
|
||||
|
||||
Returns:
|
||||
Whether the value is of a mutable type.
|
||||
"""
|
||||
return isinstance(value, cls.__mutable_types__)
|
||||
|
||||
def _wrap_recursive(self, value: Any) -> Any:
|
||||
"""Wrap a value recursively if it is mutable.
|
||||
|
||||
@ -3608,9 +3636,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
The wrapped value.
|
||||
"""
|
||||
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
|
||||
if isinstance(value, self.__mutable_types__) and not isinstance(
|
||||
value, MutableProxy
|
||||
):
|
||||
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
|
||||
return type(self)(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
@ -3668,7 +3694,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
self._wrap_recursive_decorator,
|
||||
)
|
||||
|
||||
if isinstance(value, self.__mutable_types__) and __name not in (
|
||||
if self._is_mutable_type(value) and __name not in (
|
||||
"__wrapped__",
|
||||
"_self_state",
|
||||
):
|
||||
|
@ -270,6 +270,53 @@ def serialize_base(value: Base) -> dict:
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
@serializer(to=dict)
|
||||
def serialize_base_model_v1(model: BaseModelV1) -> dict:
|
||||
"""Serialize a pydantic v1 BaseModel instance.
|
||||
|
||||
Args:
|
||||
model: The BaseModel to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized BaseModel.
|
||||
"""
|
||||
return model.dict()
|
||||
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
|
||||
if BaseModelV1 is not BaseModelV2:
|
||||
|
||||
@serializer(to=dict)
|
||||
def serialize_base_model_v2(model: BaseModelV2) -> dict:
|
||||
"""Serialize a pydantic v2 BaseModel instance.
|
||||
|
||||
Args:
|
||||
model: The BaseModel to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized BaseModel.
|
||||
"""
|
||||
return model.model_dump()
|
||||
except ImportError:
|
||||
# Older pydantic v1 import
|
||||
from pydantic import BaseModel as BaseModelV1
|
||||
|
||||
@serializer(to=dict)
|
||||
def serialize_base_model_v1(model: BaseModelV1) -> dict:
|
||||
"""Serialize a pydantic v1 BaseModel instance.
|
||||
|
||||
Args:
|
||||
model: The BaseModel to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized BaseModel.
|
||||
"""
|
||||
return model.dict()
|
||||
|
||||
|
||||
@serializer
|
||||
def serialize_set(value: Set) -> list:
|
||||
"""Serialize a set to a JSON serializable list.
|
||||
|
@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from plotly.graph_objects import Figure
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
import reflex as rx
|
||||
import reflex.config
|
||||
@ -3413,6 +3415,53 @@ def test_typed_state() -> None:
|
||||
_ = TypedState(field="str")
|
||||
|
||||
|
||||
class ModelV1(BaseModelV1):
|
||||
"""A pydantic BaseModel v1."""
|
||||
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
class ModelV2(BaseModelV2):
|
||||
"""A pydantic BaseModel v2."""
|
||||
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelDC:
|
||||
"""A dataclass."""
|
||||
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
class PydanticState(rx.State):
|
||||
"""A state with pydantic BaseModel vars."""
|
||||
|
||||
v1: ModelV1 = ModelV1()
|
||||
v2: ModelV2 = ModelV2()
|
||||
dc: ModelDC = ModelDC()
|
||||
|
||||
|
||||
def test_mutable_models():
|
||||
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
|
||||
state = PydanticState()
|
||||
assert isinstance(state.v1, MutableProxy)
|
||||
state.v1.foo = "baz"
|
||||
assert state.dirty_vars == {"v1"}
|
||||
state.dirty_vars.clear()
|
||||
|
||||
assert isinstance(state.v2, MutableProxy)
|
||||
state.v2.foo = "baz"
|
||||
assert state.dirty_vars == {"v2"}
|
||||
state.dirty_vars.clear()
|
||||
|
||||
# Not yet supported ENG-4083
|
||||
# assert isinstance(state.dc, MutableProxy)
|
||||
# state.dc.foo = "baz"
|
||||
# assert state.dirty_vars == {"dc"}
|
||||
# state.dirty_vars.clear()
|
||||
|
||||
|
||||
def test_get_value():
|
||||
class GetValueState(rx.State):
|
||||
foo: str = "FOO"
|
||||
|
Loading…
Reference in New Issue
Block a user