From 52d1a2c5ec0536635a98e3771ce8d8cc0555601d Mon Sep 17 00:00:00 2001
From: Masen Furer
Date: Thu, 3 Oct 2024 13:53:58 -0700
Subject: [PATCH] Use regular `pickle` module from stdlib

---
Reviewer notes (ignored by `git am`; everything up to the first `diff --git`
is commentary):

* Line structure was restored from a whitespace-collapsed copy of this patch.
  Indentation and blank context lines were reconstructed from the hunk line
  counts and Python syntax -- verify against reflex/state.py at blob
  b1988e38a before applying.
* The ValueError text in `BaseState._deserialize` was reworded from
  "Only one of ..." to "Exactly one of ...", because the same branch is taken
  when neither argument is given. The post-image blob id on the `index` line
  therefore no longer matches the resulting file.
* `_deserialize` unpickles whatever bytes it is handed; the redis instance and
  the on-disk states directory must be treated as trusted storage.

 reflex/state.py            | 59 ++++++++++++++++++++++++++++++++------
 reflex/utils/exceptions.py |  4 +++
 2 files changed, 55 insertions(+), 8 deletions(-)

diff --git a/reflex/state.py b/reflex/state.py
index b1988e38a..0d2056cd0 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -9,6 +9,7 @@ import dataclasses
 import functools
 import inspect
 import os
+import pickle
 import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
@@ -19,6 +20,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     AsyncIterator,
+    BinaryIO,
     Callable,
     ClassVar,
     Dict,
@@ -76,6 +78,7 @@ from reflex.utils.exceptions import (
     ImmutableStateError,
     LockExpiredError,
     SetUndefinedStateVarError,
+    StateSchemaMismatchError,
 )
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import serializer
@@ -1914,7 +1917,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
     def __getstate__(self):
         """Get the state for redis serialization.
 
-        This method is called by cloudpickle to serialize the object.
+        This method is called by pickle to serialize the object.
 
         It explicitly removes parent_state and substates because those are serialized separately
         by the StateManagerRedis to allow for better horizontal scaling as state size increases.
@@ -1930,6 +1933,43 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         state["__dict__"].pop("_was_touched", None)
         return state
 
+    def _serialize(self) -> bytes:
+        """Serialize the state for redis.
+
+        Returns:
+            The serialized state.
+        """
+        return pickle.dumps((state_to_schema(self), self))
+
+    @classmethod
+    def _deserialize(
+        cls, data: bytes | None = None, fp: BinaryIO | None = None
+    ) -> BaseState:
+        """Deserialize the state from redis/disk.
+
+        data and fp are mutually exclusive, but one must be provided.
+
+        Args:
+            data: The serialized state data.
+            fp: The file pointer to the serialized state data.
+
+        Returns:
+            The deserialized state.
+
+        Raises:
+            ValueError: If both data and fp are provided, or neither are provided.
+            StateSchemaMismatchError: If the state schema does not match the expected schema.
+        """
+        if data is not None and fp is None:
+            (substate_schema, state) = pickle.loads(data)
+        elif fp is not None and data is None:
+            (substate_schema, state) = pickle.load(fp)
+        else:
+            raise ValueError("Exactly one of `data` or `fp` must be provided")
+        if substate_schema != state_to_schema(state):
+            raise StateSchemaMismatchError()
+        return state
+
 
 class State(BaseState):
     """The app Base State."""
@@ -2086,7 +2126,11 @@ class ComponentState(State, mixin=True):
         """
         cls._per_component_state_instance_count += 1
         state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
-        component_state = type(state_cls_name, (cls, State), {}, mixin=False)
+        component_state = type(
+            state_cls_name, (cls, State), {"__module__": __name__}, mixin=False
+        )
+        # Save a reference to the dynamic state for pickle/unpickle.
+        globals()[state_cls_name] = component_state
         component = component_state.get_component(*children, **props)
         component.State = component_state
         return component
@@ -2552,7 +2596,7 @@ def is_serializable(value: Any) -> bool:
         Whether the value is serializable.
     """
     try:
-        return bool(dill.dumps(value))
+        return bool(pickle.dumps(value))
     except Exception:
         return False
 
@@ -2688,8 +2732,7 @@ class StateManagerDisk(StateManager):
         if token_path.exists():
             try:
                 with token_path.open(mode="rb") as file:
-                    (substate_schema, substate) = dill.load(file)
-                if substate_schema == state_to_schema(substate):
+                    substate = BaseState._deserialize(fp=file)
                     await self.populate_substates(client_token, substate, root_state)
                     return substate
             except Exception:
@@ -2747,7 +2790,7 @@ class StateManagerDisk(StateManager):
 
         self.states[substate_token] = substate
 
-        state_dilled = dill.dumps((state_to_schema(substate), substate))
+        state_dilled = substate._serialize()
         if not self.states_directory.exists():
             self.states_directory.mkdir(parents=True, exist_ok=True)
         self.token_path(substate_token).write_bytes(state_dilled)
@@ -2948,7 +2991,7 @@ class StateManagerRedis(StateManager):
 
         if redis_state is not None:
             # Deserialize the substate.
-            state = dill.loads(redis_state)
+            state = BaseState._deserialize(data=redis_state)
 
             # Populate parent state if missing and requested.
             if parent_state is None:
@@ -3060,7 +3103,7 @@ class StateManagerRedis(StateManager):
             )
         # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
         if state._get_was_touched():
-            pickle_state = dill.dumps(state, byref=True)
+            pickle_state = state._serialize()
             self._warn_if_too_large(state, len(pickle_state))
             await self.redis.set(
                 _substate_key(client_token, state),
diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py
index 0383f7ba6..8bce605b5 100644
--- a/reflex/utils/exceptions.py
+++ b/reflex/utils/exceptions.py
@@ -123,3 +123,7 @@ class DynamicComponentMissingLibrary(ReflexError, ValueError):
 
 class SetUndefinedStateVarError(ReflexError, AttributeError):
     """Raised when setting the value of a var without first declaring it."""
+
+
+class StateSchemaMismatchError(ReflexError, TypeError):
+    """Raised when the serialized schema of a state class does not match the current schema."""