[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var
Provide serializers and mutable proxy tracking for pydantic models directly.
This commit is contained in:
parent
cd59ab5406
commit
91437317d3
@ -39,6 +39,8 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -3520,7 +3522,15 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
pydantic.BaseModel.__dict__
|
||||
)
|
||||
|
||||
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
|
||||
__mutable_types__ = (
|
||||
list,
|
||||
dict,
|
||||
set,
|
||||
Base,
|
||||
DeclarativeBase,
|
||||
BaseModelV2,
|
||||
BaseModelV1,
|
||||
)
|
||||
|
||||
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
|
||||
"""Create a proxy for a mutable object that tracks changes.
|
||||
|
@ -24,6 +24,9 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.constants.colors import Color, format_color
|
||||
from reflex.utils import types
|
||||
@ -266,6 +269,32 @@ def serialize_base(value: Base) -> dict:
|
||||
return {k: v for k, v in value.dict().items() if not callable(v)}
|
||||
|
||||
|
||||
@serializer(to=dict)
|
||||
def serialize_base_model_v1(model: BaseModelV1) -> dict:
|
||||
"""Serialize a pydantic v1 BaseModel instance.
|
||||
|
||||
Args:
|
||||
model: The BaseModel to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized BaseModel.
|
||||
"""
|
||||
return model.dict()
|
||||
|
||||
|
||||
@serializer(to=dict)
|
||||
def serialize_base_model_v2(model: BaseModelV2) -> dict:
|
||||
"""Serialize a pydantic v2 BaseModel instance.
|
||||
|
||||
Args:
|
||||
model: The BaseModel to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized BaseModel.
|
||||
"""
|
||||
return model.model_dump()
|
||||
|
||||
|
||||
@serializer
|
||||
def serialize_set(value: Set) -> list:
|
||||
"""Serialize a set to a JSON serializable list.
|
||||
|
@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from plotly.graph_objects import Figure
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
import reflex as rx
|
||||
import reflex.config
|
||||
@ -3411,3 +3413,34 @@ def test_typed_state() -> None:
|
||||
field: rx.Field[str] = rx.field("")
|
||||
|
||||
_ = TypedState(field="str")
|
||||
|
||||
|
||||
class ModelV1(BaseModelV1):
|
||||
"""A pydantic BaseModel v1."""
|
||||
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
class ModelV2(BaseModelV2):
|
||||
"""A pydantic BaseModel v2."""
|
||||
|
||||
foo: str = "bar"
|
||||
|
||||
|
||||
class PydanticState(rx.State):
|
||||
"""A state with pydantic BaseModel vars."""
|
||||
|
||||
v1: ModelV1 = ModelV1()
|
||||
v2: ModelV2 = ModelV2()
|
||||
|
||||
|
||||
def test_pydantic_base_models():
|
||||
"""Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking."""
|
||||
state = PydanticState()
|
||||
assert isinstance(state.v1, MutableProxy)
|
||||
state.v1.foo = "baz"
|
||||
assert "v1" in state.dirty_vars
|
||||
|
||||
assert isinstance(state.v2, MutableProxy)
|
||||
state.v2.foo = "baz"
|
||||
assert "v2" in state.dirty_vars
|
||||
|
Loading…
Reference in New Issue
Block a user