diff --git a/reflex/state.py b/reflex/state.py index 7bdbcdc2b..cc9dda05b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -105,6 +105,15 @@ var = computed_var # If the state is this large, it's considered a performance issue. TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb +# Errors caught during pickling of state +HANDLED_PICKLE_ERRORS = ( + pickle.PicklingError, + AttributeError, + IndexError, + TypeError, + ValueError, +) + def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable @@ -2076,7 +2085,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ try: return pickle.dumps((self._to_schema(), self)) - except (pickle.PicklingError, AttributeError) as og_pickle_error: + except HANDLED_PICKLE_ERRORS as og_pickle_error: error = ( f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " "This state will not be persisted. " @@ -2090,7 +2099,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Pickle error: {og_pickle_error}. " "Consider `pip install 'dill>=0.3.8'` for more exotic serialization support." ) - except (pickle.PicklingError, TypeError, ValueError) as ex: + except HANDLED_PICKLE_ERRORS as ex: error += f"Dill was also unable to pickle the state: {ex}" console.warn(error) return b"" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 8397954cf..fe2f652ac 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -8,6 +8,7 @@ import functools import json import os import sys +import threading from textwrap import dedent from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union from unittest.mock import AsyncMock, Mock @@ -3390,9 +3391,15 @@ def test_fallback_pickle(): assert unpickled_state._f() == 420 assert unpickled_state._o._f() == 42 + # Threading locks are unpicklable normally, and raise TypeError instead of PicklingError. + state2 = DillState(_reflex_internal_init=True) # type: ignore + state2._g = threading.Lock() + pk2 = state2._serialize() + unpickled_state2 = BaseState._deserialize(pk2) + assert isinstance(unpickled_state2._g, type(threading.Lock())) + # Some object, like generator, are still unpicklable with dill. - state._g = (i for i in range(10)) - pk = state._serialize() - assert len(pk) == 0 - with pytest.raises(EOFError): - BaseState._deserialize(pk) + state3 = DillState(_reflex_internal_init=True) # type: ignore + state3._g = (i for i in range(10)) + pk3 = state3._serialize() + assert len(pk3) == 0