diff --git a/reflex/state.py b/reflex/state.py index 442fa57b2..349dc59e9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -62,6 +62,13 @@ try: except ModuleNotFoundError: import pydantic +from pydantic import BaseModel as BaseModelV2 + +try: + from pydantic.v1 import BaseModel as BaseModelV1 +except ModuleNotFoundError: + BaseModelV1 = BaseModelV2 + import wrapt from redis.asyncio import Redis from redis.exceptions import ResponseError @@ -1250,7 +1257,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if parent_state is not None: return getattr(parent_state, name) - if isinstance(value, MutableProxy.__mutable_types__) and ( + if MutableProxy._is_mutable_type(value) and ( name in super().__getattribute__("base_vars") or name in backend_vars ): # track changes in mutable containers (list, dict, set, etc) @@ -3558,7 +3565,16 @@ class MutableProxy(wrapt.ObjectProxy): pydantic.BaseModel.__dict__ ) - __mutable_types__ = (list, dict, set, Base, DeclarativeBase) + # These types will be wrapped in MutableProxy + __mutable_types__ = ( + list, + dict, + set, + Base, + DeclarativeBase, + BaseModelV2, + BaseModelV1, + ) def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. @@ -3598,6 +3614,18 @@ class MutableProxy(wrapt.ObjectProxy): if wrapped is not None: return wrapped(*args, **(kwargs or {})) + @classmethod + def _is_mutable_type(cls, value: Any) -> bool: + """Check if a value is of a mutable type and should be wrapped. + + Args: + value: The value to check. + + Returns: + Whether the value is of a mutable type. + """ + return isinstance(value, cls.__mutable_types__) + def _wrap_recursive(self, value: Any) -> Any: """Wrap a value recursively if it is mutable. @@ -3608,9 +3636,7 @@ class MutableProxy(wrapt.ObjectProxy): The wrapped value. """ # Recursively wrap mutable types, but do not re-wrap MutableProxy instances. - if isinstance(value, self.__mutable_types__) and not isinstance( - value, MutableProxy - ): + if self._is_mutable_type(value) and not isinstance(value, MutableProxy): return type(self)( wrapped=value, state=self._self_state, @@ -3668,7 +3694,7 @@ class MutableProxy(wrapt.ObjectProxy): self._wrap_recursive_decorator, ) - if isinstance(value, self.__mutable_types__) and __name not in ( + if self._is_mutable_type(value) and __name not in ( "__wrapped__", "_self_state", ): diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index b0ad935c8..4bb8dea92 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -270,6 +270,53 @@ def serialize_base(value: Base) -> dict: } +try: + from pydantic.v1 import BaseModel as BaseModelV1 + + @serializer(to=dict) + def serialize_base_model_v1(model: BaseModelV1) -> dict: + """Serialize a pydantic v1 BaseModel instance. + + Args: + model: The BaseModel to serialize. + + Returns: + The serialized BaseModel. + """ + return model.dict() + + from pydantic import BaseModel as BaseModelV2 + + if BaseModelV1 is not BaseModelV2: + + @serializer(to=dict) + def serialize_base_model_v2(model: BaseModelV2) -> dict: + """Serialize a pydantic v2 BaseModel instance. + + Args: + model: The BaseModel to serialize. + + Returns: + The serialized BaseModel. + """ + return model.model_dump() +except ImportError: + # Older pydantic v1 import + from pydantic import BaseModel as BaseModelV1 + + @serializer(to=dict) + def serialize_base_model_v1(model: BaseModelV1) -> dict: + """Serialize a pydantic v1 BaseModel instance. + + Args: + model: The BaseModel to serialize. + + Returns: + The serialized BaseModel. + """ + return model.dict() + + @serializer def serialize_set(value: Set) -> list: """Serialize a set to a JSON serializable list. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 7cebaff8e..c8a52e6c0 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -16,6 +16,8 @@ from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio from plotly.graph_objects import Figure +from pydantic import BaseModel as BaseModelV2 +from pydantic.v1 import BaseModel as BaseModelV1 import reflex as rx import reflex.config @@ -3413,6 +3415,53 @@ def test_typed_state() -> None: _ = TypedState(field="str") +class ModelV1(BaseModelV1): + """A pydantic BaseModel v1.""" + + foo: str = "bar" + + +class ModelV2(BaseModelV2): + """A pydantic BaseModel v2.""" + + foo: str = "bar" + + +@dataclasses.dataclass +class ModelDC: + """A dataclass.""" + + foo: str = "bar" + + +class PydanticState(rx.State): + """A state with pydantic BaseModel vars.""" + + v1: ModelV1 = ModelV1() + v2: ModelV2 = ModelV2() + dc: ModelDC = ModelDC() + + +def test_mutable_models(): + """Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking.""" + state = PydanticState() + assert isinstance(state.v1, MutableProxy) + state.v1.foo = "baz" + assert state.dirty_vars == {"v1"} + state.dirty_vars.clear() + + assert isinstance(state.v2, MutableProxy) + state.v2.foo = "baz" + assert state.dirty_vars == {"v2"} + state.dirty_vars.clear() + + # Not yet supported ENG-4083 + # assert isinstance(state.dc, MutableProxy) + # state.dc.foo = "baz" + # assert state.dirty_vars == {"dc"} + # state.dirty_vars.clear() + + def test_get_value(): class GetValueState(rx.State): foo: str = "FOO"