[ENG-4013] Catch more exceptions for dill pickle fallback (#4270)

Additionally catch TypeError, IndexError, and ValueError which may be thrown
when attempting to pickle unpicklable objects.
This commit is contained in:
Masen Furer 2024-10-30 16:50:19 -07:00 committed by GitHub
parent c288741cab
commit 2ab662b757
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 7 deletions

View File

@ -105,6 +105,15 @@ var = computed_var
# If the state is this large, it's considered a performance issue.
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
# Errors caught during pickling of state
HANDLED_PICKLE_ERRORS = (
pickle.PicklingError,
AttributeError,
IndexError,
TypeError,
ValueError,
)
def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable
@ -2076,7 +2085,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
try:
return pickle.dumps((self._to_schema(), self))
except (pickle.PicklingError, AttributeError) as og_pickle_error:
except HANDLED_PICKLE_ERRORS as og_pickle_error:
error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
"This state will not be persisted. "
@ -2090,7 +2099,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Pickle error: {og_pickle_error}. "
"Consider `pip install 'dill>=0.3.8'` for more exotic serialization support."
)
except (pickle.PicklingError, TypeError, ValueError) as ex:
except HANDLED_PICKLE_ERRORS as ex:
error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error)
return b""

View File

@ -8,6 +8,7 @@ import functools
import json
import os
import sys
import threading
from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
from unittest.mock import AsyncMock, Mock
@ -3390,9 +3391,15 @@ def test_fallback_pickle():
assert unpickled_state._f() == 420
assert unpickled_state._o._f() == 42
# Threading locks are unpicklable normally, and raise TypeError instead of PicklingError.
state2 = DillState(_reflex_internal_init=True) # type: ignore
state2._g = threading.Lock()
pk2 = state2._serialize()
unpickled_state2 = BaseState._deserialize(pk2)
assert isinstance(unpickled_state2._g, type(threading.Lock()))
# Some object, like generator, are still unpicklable with dill.
state._g = (i for i in range(10))
pk = state._serialize()
assert len(pk) == 0
with pytest.raises(EOFError):
BaseState._deserialize(pk)
state3 = DillState(_reflex_internal_init=True) # type: ignore
state3._g = (i for i in range(10))
pk3 = state3._serialize()
assert len(pk3) == 0