diff --git a/poetry.lock b/poetry.lock index ff9323ad4..71be30d76 100644 --- a/poetry.lock +++ b/poetry.lock @@ -521,6 +521,21 @@ files = [ {file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"}, ] +[[package]] +name = "dill" +version = "0.3.9" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, + {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distlib" version = "0.3.9" @@ -1333,8 +1348,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -1652,8 +1667,8 @@ files = [ annotated-types = ">=0.6.0" pydantic-core = "2.23.4" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -3033,4 +3048,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8090ccaeca173bd8612e17a0b8d157d7492618e49450abd1c8373e2976349db0" +content-hash = "e03374b85bf10f0a7bb857969b2d6714f25affa63e14a48a88be9fa154b24326" diff --git a/pyproject.toml b/pyproject.toml index 93f3c5d50..2635e1156 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ pytest = ">=7.1.2,<9.0" pytest-mock = ">=3.10.0,<4.0" pyright = ">=1.1.229,<1.1.335" darglint = ">=1.8.1,<2.0" +dill = ">=0.3.8" toml = ">=0.10.2,<1.0" pytest-asyncio = ">=0.24.0" pytest-cov = ">=4.0.0,<6.0" diff --git a/reflex/state.py b/reflex/state.py index e3e189b22..6e229b97d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2063,12 +2063,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ try: return pickle.dumps((self._to_schema(), self)) - except pickle.PicklingError: - console.warn( + except (pickle.PicklingError, AttributeError) as og_pickle_error: + error = ( f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " - "This state will not be persisted." + "This state will not be persisted. " ) - return b"" + try: + import dill + + return dill.dumps((self._to_schema(), self)) + except ImportError: + error += ( + f"Pickle error: {og_pickle_error}. " + "Consider `pip install 'dill>=0.3.8'` for more exotic serialization support." + ) + except (pickle.PicklingError, TypeError, ValueError) as ex: + error += f"Dill was also unable to pickle the state: {ex}" + console.warn(error) + return b"" @classmethod def _deserialize( diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 544ddc606..89dd1fd3d 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3364,3 +3364,35 @@ async def test_deserialize_gc_state_disk(token): assert s.num == 43 c = await root.get_state(Child) assert c.foo == "bar" + + +class Obj(Base): + """A object containing a callable for testing fallback pickle.""" + + _f: Callable + + +def test_fallback_pickle(): + """Test that state serialization will fall back to dill.""" + + class DillState(BaseState): + _o: Optional[Obj] = None + _f: Optional[Callable] = None + _g: Any = None + + state = DillState(_reflex_internal_init=True) # type: ignore + state._o = Obj(_f=lambda: 42) + state._f = lambda: 420 + + pk = state._serialize() + + unpickled_state = BaseState._deserialize(pk) + assert unpickled_state._f() == 420 + assert unpickled_state._o._f() == 42 + + # Some object, like generator, are still unpicklable with dill. + state._g = (i for i in range(10)) + pk = state._serialize() + assert len(pk) == 0 + with pytest.raises(EOFError): + BaseState._deserialize(pk)