Use regular pickle module from stdlib

This commit is contained in:
Masen Furer 2024-10-03 13:53:58 -07:00
parent fafdeb892e
commit 52d1a2c5ec
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 55 additions and 8 deletions

View File

@ -9,6 +9,7 @@ import dataclasses
import functools
import inspect
import os
import pickle
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
@ -19,6 +20,7 @@ from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
BinaryIO,
Callable,
ClassVar,
Dict,
@ -76,6 +78,7 @@ from reflex.utils.exceptions import (
ImmutableStateError,
LockExpiredError,
SetUndefinedStateVarError,
StateSchemaMismatchError,
)
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import serializer
@ -1914,7 +1917,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def __getstate__(self):
"""Get the state for redis serialization.
This method is called by cloudpickle to serialize the object.
This method is called by pickle to serialize the object.
It explicitly removes parent_state and substates because those are serialized separately
by the StateManagerRedis to allow for better horizontal scaling as state size increases.
@ -1930,6 +1933,43 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
state["__dict__"].pop("_was_touched", None)
return state
def _serialize(self) -> bytes:
    """Pickle this state together with its schema.

    The payload is a ``(schema, state)`` tuple so that the schema can be
    compared against the current one when the bytes are loaded again.

    Returns:
        The pickled ``(schema, state)`` payload.
    """
    schema = state_to_schema(self)
    payload = (schema, self)
    return pickle.dumps(payload)
@classmethod
def _deserialize(
    cls, data: bytes | None = None, fp: BinaryIO | None = None
) -> BaseState:
    """Deserialize the state from redis/disk.

    data and fp are mutually exclusive, but one must be provided.

    Args:
        data: The serialized state data.
        fp: The file pointer to the serialized state data.

    Returns:
        The deserialized state.

    Raises:
        ValueError: If both data and fp are provided, or neither are provided.
        StateSchemaMismatchError: If the payload is not a ``(schema, state)``
            pair, or the state schema does not match the expected schema.
    """
    # NOTE(security): unpickling can execute arbitrary code. The bytes passed
    # here must come from storage that only this application writes to.
    if data is not None and fp is None:
        payload = pickle.loads(data)
    elif fp is not None and data is None:
        payload = pickle.load(fp)
    else:
        raise ValueError("Only one of `data` or `fp` must be provided")

    # A payload that is not a 2-tuple was written in some other format and
    # cannot be validated; report it as a schema mismatch instead of failing
    # with an unrelated unpacking error (or silently mis-destructuring it).
    if not (isinstance(payload, tuple) and len(payload) == 2):
        raise StateSchemaMismatchError()

    substate_schema, state = payload
    if substate_schema != state_to_schema(state):
        raise StateSchemaMismatchError()
    return state
class State(BaseState):
"""The app Base State."""
@ -2086,7 +2126,11 @@ class ComponentState(State, mixin=True):
"""
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(state_cls_name, (cls, State), {}, mixin=False)
component_state = type(
state_cls_name, (cls, State), {"__module__": __name__}, mixin=False
)
# Save a reference to the dynamic state for pickle/unpickle.
globals()[state_cls_name] = component_state
component = component_state.get_component(*children, **props)
component.State = component_state
return component
@ -2552,7 +2596,7 @@ def is_serializable(value: Any) -> bool:
Whether the value is serializable.
"""
try:
return bool(dill.dumps(value))
return bool(pickle.dumps(value))
except Exception:
return False
@ -2688,8 +2732,7 @@ class StateManagerDisk(StateManager):
if token_path.exists():
try:
with token_path.open(mode="rb") as file:
(substate_schema, substate) = dill.load(file)
if substate_schema == state_to_schema(substate):
substate = BaseState._deserialize(fp=file)
await self.populate_substates(client_token, substate, root_state)
return substate
except Exception:
@ -2747,7 +2790,7 @@ class StateManagerDisk(StateManager):
self.states[substate_token] = substate
state_dilled = dill.dumps((state_to_schema(substate), substate))
state_dilled = substate._serialize()
if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(state_dilled)
@ -2948,7 +2991,7 @@ class StateManagerRedis(StateManager):
if redis_state is not None:
# Deserialize the substate.
state = dill.loads(redis_state)
state = BaseState._deserialize(data=redis_state)
# Populate parent state if missing and requested.
if parent_state is None:
@ -3060,7 +3103,7 @@ class StateManagerRedis(StateManager):
)
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
if state._get_was_touched():
pickle_state = dill.dumps(state, byref=True)
pickle_state = state._serialize()
self._warn_if_too_large(state, len(pickle_state))
await self.redis.set(
_substate_key(client_token, state),

View File

@ -123,3 +123,7 @@ class DynamicComponentMissingLibrary(ReflexError, ValueError):
class SetUndefinedStateVarError(ReflexError, AttributeError):
    """Error raised on an attempt to set a var that has not been declared first."""
class StateSchemaMismatchError(ReflexError, TypeError):
    """Error raised when a serialized state's stored schema differs from the current schema of its state class."""