diff --git a/reflex/state.py b/reflex/state.py index 66b1e3cab..a87b9c3e7 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -39,6 +39,8 @@ from typing import ( get_type_hints, ) +from pydantic import BaseModel as BaseModelV2 +from pydantic.v1 import BaseModel as BaseModelV1 from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self @@ -3520,7 +3522,15 @@ class MutableProxy(wrapt.ObjectProxy): pydantic.BaseModel.__dict__ ) - __mutable_types__ = (list, dict, set, Base, DeclarativeBase) + __mutable_types__ = ( + list, + dict, + set, + Base, + DeclarativeBase, + BaseModelV2, + BaseModelV1, + ) def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index b87909aec..480eadf35 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -24,6 +24,9 @@ from typing import ( overload, ) +from pydantic import BaseModel as BaseModelV2 +from pydantic.v1 import BaseModel as BaseModelV1 + from reflex.base import Base from reflex.constants.colors import Color, format_color from reflex.utils import types @@ -266,6 +269,32 @@ def serialize_base(value: Base) -> dict: return {k: v for k, v in value.dict().items() if not callable(v)} +@serializer(to=dict) +def serialize_base_model_v1(model: BaseModelV1) -> dict: + """Serialize a pydantic v1 BaseModel instance. + + Args: + model: The BaseModel to serialize. + + Returns: + The serialized BaseModel. + """ + return model.dict() + + +@serializer(to=dict) +def serialize_base_model_v2(model: BaseModelV2) -> dict: + """Serialize a pydantic v2 BaseModel instance. + + Args: + model: The BaseModel to serialize. + + Returns: + The serialized BaseModel. + """ + return model.model_dump() + + @serializer def serialize_set(value: Set) -> list: """Serialize a set to a JSON serializable list. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2ce0b7bd5..6bb130822 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio from plotly.graph_objects import Figure +from pydantic import BaseModel as BaseModelV2 +from pydantic.v1 import BaseModel as BaseModelV1 import reflex as rx import reflex.config @@ -3411,3 +3413,34 @@ def test_typed_state() -> None: field: rx.Field[str] = rx.field("") _ = TypedState(field="str") + + +class ModelV1(BaseModelV1): + """A pydantic BaseModel v1.""" + + foo: str = "bar" + + +class ModelV2(BaseModelV2): + """A pydantic BaseModel v2.""" + + foo: str = "bar" + + +class PydanticState(rx.State): + """A state with pydantic BaseModel vars.""" + + v1: ModelV1 = ModelV1() + v2: ModelV2 = ModelV2() + + +def test_pydantic_base_models(): + """Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking.""" + state = PydanticState() + assert isinstance(state.v1, MutableProxy) + state.v1.foo = "baz" + assert "v1" in state.dirty_vars + + assert isinstance(state.v2, MutableProxy) + state.v2.foo = "baz" + assert "v2" in state.dirty_vars