diff --git a/reflex/state.py b/reflex/state.py index a87b9c3e7..ebf958734 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1244,7 +1244,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if parent_state is not None: return getattr(parent_state, name) - if isinstance(value, MutableProxy.__mutable_types__) and ( + if MutableProxy._is_mutable_type(value) and ( name in super().__getattribute__("base_vars") or name in backend_vars ): # track changes in mutable containers (list, dict, set, etc) @@ -3522,6 +3522,7 @@ class MutableProxy(wrapt.ObjectProxy): pydantic.BaseModel.__dict__ ) + # These types will be wrapped in MutableProxy __mutable_types__ = ( list, dict, @@ -3570,6 +3571,18 @@ class MutableProxy(wrapt.ObjectProxy): if wrapped is not None: return wrapped(*args, **(kwargs or {})) + @classmethod + def _is_mutable_type(cls, value: Any) -> bool: + """Check if a value is of a mutable type and should be wrapped. + + Args: + value: The value to check. + + Returns: + Whether the value is of a mutable type. + """ + return isinstance(value, cls.__mutable_types__) + def _wrap_recursive(self, value: Any) -> Any: """Wrap a value recursively if it is mutable. @@ -3580,9 +3593,7 @@ class MutableProxy(wrapt.ObjectProxy): The wrapped value. """ # Recursively wrap mutable types, but do not re-wrap MutableProxy instances. - if isinstance(value, self.__mutable_types__) and not isinstance( - value, MutableProxy - ): + if self._is_mutable_type(value) and not isinstance(value, MutableProxy): return type(self)( wrapped=value, state=self._self_state, @@ -3640,7 +3651,7 @@ class MutableProxy(wrapt.ObjectProxy): self._wrap_recursive_decorator, ) - if isinstance(value, self.__mutable_types__) and __name not in ( + if self._is_mutable_type(value) and __name not in ( "__wrapped__", "_self_state", ): diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 9a0613f5a..18a7eb671 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -283,13 +283,14 @@ def serialize_base_model_v1(model: BaseModelV1) -> dict: if BaseModelV2 is not BaseModelV1: + @serializer(to=dict) def serialize_base_model_v2(model: BaseModelV2) -> dict: """Serialize a pydantic v2 BaseModel instance. - + Args: model: The BaseModel to serialize. - + Returns: The serialized BaseModel. """ diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 6bb130822..43a90135b 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3427,20 +3427,36 @@ class ModelV2(BaseModelV2): foo: str = "bar" +@dataclasses.dataclass +class ModelDC: + """A dataclass.""" + + foo: str = "bar" + + class PydanticState(rx.State): """A state with pydantic BaseModel vars.""" v1: ModelV1 = ModelV1() v2: ModelV2 = ModelV2() + dc: ModelDC = ModelDC() -def test_pydantic_base_models(): - """Test that pydantic BaseModel v1 and v2 can be used as state vars with dep tracking.""" +def test_mutable_models(): + """Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking.""" state = PydanticState() assert isinstance(state.v1, MutableProxy) state.v1.foo = "baz" - assert "v1" in state.dirty_vars + assert state.dirty_vars == {"v1"} + state.dirty_vars.clear() assert isinstance(state.v2, MutableProxy) state.v2.foo = "baz" - assert "v2" in state.dirty_vars + assert state.dirty_vars == {"v2"} + state.dirty_vars.clear() + + # Not yet supported ENG-4083 + # assert isinstance(state.dc, MutableProxy) + # state.dc.foo = "baz" + # assert state.dirty_vars == {"dc"} + # state.dirty_vars.clear()