[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var

Provide serializers and mutable proxy tracking for pydantic models directly.
This commit is contained in:
Masen Furer 2024-11-08 09:56:43 -08:00
parent cd59ab5406
commit 91437317d3
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
3 changed files with 73 additions and 1 deletions

View File

@ -39,6 +39,8 @@ from typing import (
get_type_hints,
)
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self
@ -3520,7 +3522,15 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
__mutable_types__ = (
list,
dict,
set,
Base,
DeclarativeBase,
BaseModelV2,
BaseModelV1,
)
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.

View File

@ -24,6 +24,9 @@ from typing import (
overload,
)
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from reflex.base import Base
from reflex.constants.colors import Color, format_color
from reflex.utils import types
@ -266,6 +269,32 @@ def serialize_base(value: Base) -> dict:
return {k: v for k, v in value.dict().items() if not callable(v)}
@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.dict()
@serializer(to=dict)
def serialize_base_model_v2(model: BaseModelV2) -> dict:
"""Serialize a pydantic v2 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.model_dump()
@serializer
def serialize_set(value: Set) -> list:
"""Serialize a set to a JSON serializable list.

View File

@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from plotly.graph_objects import Figure
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
import reflex as rx
import reflex.config
@ -3411,3 +3413,34 @@ def test_typed_state() -> None:
field: rx.Field[str] = rx.field("")
_ = TypedState(field="str")
class ModelV1(BaseModelV1):
"""A pydantic BaseModel v1."""
foo: str = "bar"
class ModelV2(BaseModelV2):
"""A pydantic BaseModel v2."""
foo: str = "bar"
class PydanticState(rx.State):
"""A state with pydantic BaseModel vars."""
v1: ModelV1 = ModelV1()
v2: ModelV2 = ModelV2()
def test_pydantic_base_models():
"""Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking."""
state = PydanticState()
assert isinstance(state.v1, MutableProxy)
state.v1.foo = "baz"
assert "v1" in state.dirty_vars
assert isinstance(state.v2, MutableProxy)
state.v2.foo = "baz"
assert "v2" in state.dirty_vars