Use regular pickle module from stdlib

This commit is contained in:
Masen Furer 2024-10-03 13:53:58 -07:00
parent fafdeb892e
commit 52d1a2c5ec
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 55 additions and 8 deletions

View File

@ -9,6 +9,7 @@ import dataclasses
import functools import functools
import inspect import inspect
import os import os
import pickle
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
@ -19,6 +20,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator, AsyncIterator,
BinaryIO,
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
@ -76,6 +78,7 @@ from reflex.utils.exceptions import (
ImmutableStateError, ImmutableStateError,
LockExpiredError, LockExpiredError,
SetUndefinedStateVarError, SetUndefinedStateVarError,
StateSchemaMismatchError,
) )
from reflex.utils.exec import is_testing_env from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import serializer from reflex.utils.serializers import serializer
@ -1914,7 +1917,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def __getstate__(self): def __getstate__(self):
"""Get the state for redis serialization. """Get the state for redis serialization.
This method is called by cloudpickle to serialize the object. This method is called by pickle to serialize the object.
It explicitly removes parent_state and substates because those are serialized separately It explicitly removes parent_state and substates because those are serialized separately
by the StateManagerRedis to allow for better horizontal scaling as state size increases. by the StateManagerRedis to allow for better horizontal scaling as state size increases.
@ -1930,6 +1933,43 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
state["__dict__"].pop("_was_touched", None) state["__dict__"].pop("_was_touched", None)
return state return state
def _serialize(self) -> bytes:
"""Serialize the state for redis.
Returns:
The serialized state.
"""
return pickle.dumps((state_to_schema(self), self))
@classmethod
def _deserialize(
cls, data: bytes | None = None, fp: BinaryIO | None = None
) -> BaseState:
"""Deserialize the state from redis/disk.
data and fp are mutually exclusive, but one must be provided.
Args:
data: The serialized state data.
fp: The file pointer to the serialized state data.
Returns:
The deserialized state.
Raises:
ValueError: If both data and fp are provided, or neither are provided.
StateSchemaMismatchError: If the state schema does not match the expected schema.
"""
if data is not None and fp is None:
(substate_schema, state) = pickle.loads(data)
elif fp is not None and data is None:
(substate_schema, state) = pickle.load(fp)
else:
raise ValueError("Only one of `data` or `fp` must be provided")
if substate_schema != state_to_schema(state):
raise StateSchemaMismatchError()
return state
class State(BaseState): class State(BaseState):
"""The app Base State.""" """The app Base State."""
@ -2086,7 +2126,11 @@ class ComponentState(State, mixin=True):
""" """
cls._per_component_state_instance_count += 1 cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(state_cls_name, (cls, State), {}, mixin=False) component_state = type(
state_cls_name, (cls, State), {"__module__": __name__}, mixin=False
)
# Save a reference to the dynamic state for pickle/unpickle.
globals()[state_cls_name] = component_state
component = component_state.get_component(*children, **props) component = component_state.get_component(*children, **props)
component.State = component_state component.State = component_state
return component return component
@ -2552,7 +2596,7 @@ def is_serializable(value: Any) -> bool:
Whether the value is serializable. Whether the value is serializable.
""" """
try: try:
return bool(dill.dumps(value)) return bool(pickle.dumps(value))
except Exception: except Exception:
return False return False
@ -2688,8 +2732,7 @@ class StateManagerDisk(StateManager):
if token_path.exists(): if token_path.exists():
try: try:
with token_path.open(mode="rb") as file: with token_path.open(mode="rb") as file:
(substate_schema, substate) = dill.load(file) substate = BaseState._deserialize(fp=file)
if substate_schema == state_to_schema(substate):
await self.populate_substates(client_token, substate, root_state) await self.populate_substates(client_token, substate, root_state)
return substate return substate
except Exception: except Exception:
@ -2747,7 +2790,7 @@ class StateManagerDisk(StateManager):
self.states[substate_token] = substate self.states[substate_token] = substate
state_dilled = dill.dumps((state_to_schema(substate), substate)) state_dilled = substate._serialize()
if not self.states_directory.exists(): if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True) self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(state_dilled) self.token_path(substate_token).write_bytes(state_dilled)
@ -2948,7 +2991,7 @@ class StateManagerRedis(StateManager):
if redis_state is not None: if redis_state is not None:
# Deserialize the substate. # Deserialize the substate.
state = dill.loads(redis_state) state = BaseState._deserialize(data=redis_state)
# Populate parent state if missing and requested. # Populate parent state if missing and requested.
if parent_state is None: if parent_state is None:
@ -3060,7 +3103,7 @@ class StateManagerRedis(StateManager):
) )
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__). # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
if state._get_was_touched(): if state._get_was_touched():
pickle_state = dill.dumps(state, byref=True) pickle_state = state._serialize()
self._warn_if_too_large(state, len(pickle_state)) self._warn_if_too_large(state, len(pickle_state))
await self.redis.set( await self.redis.set(
_substate_key(client_token, state), _substate_key(client_token, state),

View File

@ -123,3 +123,7 @@ class DynamicComponentMissingLibrary(ReflexError, ValueError):
class SetUndefinedStateVarError(ReflexError, AttributeError): class SetUndefinedStateVarError(ReflexError, AttributeError):
"""Raised when setting the value of a var without first declaring it.""" """Raised when setting the value of a var without first declaring it."""
class StateSchemaMismatchError(ReflexError, TypeError):
"""Raised when the serialized schema of a state class does not match the current schema."""