[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var (#4338)

* [ENG-3953] Support pydantic BaseModel (v1 and v2) as state var

Provide serializers and mutable proxy tracking for pydantic models directly.

* conditionally define v2 serializer

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>

* Add `MutableProxy._is_mutable_value` to avoid duplicate logic

* Conditionally import BaseModel to handle older pydantic v1 versions

* pre-commit fu

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
Masen Furer 2024-11-21 17:01:14 -08:00 committed by GitHub
parent 5702a18502
commit a6b324bd3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 128 additions and 6 deletions

View File

@ -62,6 +62,13 @@ try:
except ModuleNotFoundError:
import pydantic
from pydantic import BaseModel as BaseModelV2
try:
from pydantic.v1 import BaseModel as BaseModelV1
except ModuleNotFoundError:
BaseModelV1 = BaseModelV2
import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
@ -1250,7 +1257,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if parent_state is not None:
return getattr(parent_state, name)
if isinstance(value, MutableProxy.__mutable_types__) and (
if MutableProxy._is_mutable_type(value) and (
name in super().__getattribute__("base_vars") or name in backend_vars
):
# track changes in mutable containers (list, dict, set, etc)
@ -3558,7 +3565,16 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
# These types will be wrapped in MutableProxy
__mutable_types__ = (
list,
dict,
set,
Base,
DeclarativeBase,
BaseModelV2,
BaseModelV1,
)
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
@ -3598,6 +3614,18 @@ class MutableProxy(wrapt.ObjectProxy):
if wrapped is not None:
return wrapped(*args, **(kwargs or {}))
@classmethod
def _is_mutable_type(cls, value: Any) -> bool:
"""Check if a value is of a mutable type and should be wrapped.
Args:
value: The value to check.
Returns:
Whether the value is of a mutable type.
"""
return isinstance(value, cls.__mutable_types__)
def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable.
@ -3608,9 +3636,7 @@ class MutableProxy(wrapt.ObjectProxy):
The wrapped value.
"""
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance(
value, MutableProxy
):
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
return type(self)(
wrapped=value,
state=self._self_state,
@ -3668,7 +3694,7 @@ class MutableProxy(wrapt.ObjectProxy):
self._wrap_recursive_decorator,
)
if isinstance(value, self.__mutable_types__) and __name not in (
if self._is_mutable_type(value) and __name not in (
"__wrapped__",
"_self_state",
):

View File

@ -270,6 +270,53 @@ def serialize_base(value: Base) -> dict:
}
try:
from pydantic.v1 import BaseModel as BaseModelV1
@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.dict()
from pydantic import BaseModel as BaseModelV2
if BaseModelV1 is not BaseModelV2:
@serializer(to=dict)
def serialize_base_model_v2(model: BaseModelV2) -> dict:
"""Serialize a pydantic v2 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.model_dump()
except ImportError:
# Older pydantic v1 import
from pydantic import BaseModel as BaseModelV1
@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.dict()
@serializer
def serialize_set(value: Set) -> list:
"""Serialize a set to a JSON serializable list.

View File

@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from plotly.graph_objects import Figure
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
import reflex as rx
import reflex.config
@ -3413,6 +3415,53 @@ def test_typed_state() -> None:
_ = TypedState(field="str")
class ModelV1(BaseModelV1):
"""A pydantic BaseModel v1."""
foo: str = "bar"
class ModelV2(BaseModelV2):
"""A pydantic BaseModel v2."""
foo: str = "bar"
@dataclasses.dataclass
class ModelDC:
"""A dataclass."""
foo: str = "bar"
class PydanticState(rx.State):
"""A state with pydantic BaseModel vars."""
v1: ModelV1 = ModelV1()
v2: ModelV2 = ModelV2()
dc: ModelDC = ModelDC()
def test_mutable_models():
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
state = PydanticState()
assert isinstance(state.v1, MutableProxy)
state.v1.foo = "baz"
assert state.dirty_vars == {"v1"}
state.dirty_vars.clear()
assert isinstance(state.v2, MutableProxy)
state.v2.foo = "baz"
assert state.dirty_vars == {"v2"}
state.dirty_vars.clear()
# Not yet supported ENG-4083
# assert isinstance(state.dc, MutableProxy)
# state.dc.foo = "baz"
# assert state.dirty_vars == {"dc"}
# state.dirty_vars.clear()
def test_get_value():
class GetValueState(rx.State):
foo: str = "FOO"