[ENG-4013] Catch more exceptions for dill pickle fallback (#4270)

Additionally catch TypeError, IndexError, and ValueError which may be thrown
when attempting to pickle unpicklable objects.
This commit is contained in:
Masen Furer 2024-10-30 16:50:19 -07:00 committed by GitHub
parent c288741cab
commit 2ab662b757
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 7 deletions

View File

@ -105,6 +105,15 @@ var = computed_var
# If the state is this large, it's considered a performance issue. # If the state is this large, it's considered a performance issue.
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
# Errors caught during pickling of state
HANDLED_PICKLE_ERRORS = (
pickle.PicklingError,
AttributeError,
IndexError,
TypeError,
ValueError,
)
def _no_chain_background_task( def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable state_cls: Type["BaseState"], name: str, fn: Callable
@ -2076,7 +2085,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
""" """
try: try:
return pickle.dumps((self._to_schema(), self)) return pickle.dumps((self._to_schema(), self))
except (pickle.PicklingError, AttributeError) as og_pickle_error: except HANDLED_PICKLE_ERRORS as og_pickle_error:
error = ( error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
"This state will not be persisted. " "This state will not be persisted. "
@ -2090,7 +2099,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Pickle error: {og_pickle_error}. " f"Pickle error: {og_pickle_error}. "
"Consider `pip install 'dill>=0.3.8'` for more exotic serialization support." "Consider `pip install 'dill>=0.3.8'` for more exotic serialization support."
) )
except (pickle.PicklingError, TypeError, ValueError) as ex: except HANDLED_PICKLE_ERRORS as ex:
error += f"Dill was also unable to pickle the state: {ex}" error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error) console.warn(error)
return b"" return b""

View File

@ -8,6 +8,7 @@ import functools
import json import json
import os import os
import sys import sys
import threading
from textwrap import dedent from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
@ -3390,9 +3391,15 @@ def test_fallback_pickle():
assert unpickled_state._f() == 420 assert unpickled_state._f() == 420
assert unpickled_state._o._f() == 42 assert unpickled_state._o._f() == 42
# Threading locks are unpicklable normally, and raise TypeError instead of PicklingError.
state2 = DillState(_reflex_internal_init=True) # type: ignore
state2._g = threading.Lock()
pk2 = state2._serialize()
unpickled_state2 = BaseState._deserialize(pk2)
assert isinstance(unpickled_state2._g, type(threading.Lock()))
# Some object, like generator, are still unpicklable with dill. # Some object, like generator, are still unpicklable with dill.
state._g = (i for i in range(10)) state3 = DillState(_reflex_internal_init=True) # type: ignore
pk = state._serialize() state3._g = (i for i in range(10))
assert len(pk) == 0 pk3 = state3._serialize()
with pytest.raises(EOFError): assert len(pk3) == 0
BaseState._deserialize(pk)