diff --git a/reflex/config.py b/reflex/config.py index 0579b019f..fa54fc182 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -567,6 +567,12 @@ class EnvironmentVariables: # The maximum size of the reflex state in kilobytes. REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000) + # Whether to compress the reflex state. + REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False) + + # Threshold for the reflex state compression in bytes. + REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024) + environment = EnvironmentVariables() diff --git a/reflex/state.py b/reflex/state.py index 1cd3e2c3e..d0e76370d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -17,6 +17,7 @@ import uuid from abc import ABC, abstractmethod from collections import defaultdict from hashlib import md5 +from io import BytesIO from pathlib import Path from types import FunctionType, MethodType from typing import ( @@ -143,6 +144,10 @@ HANDLED_PICKLE_ERRORS = ( ValueError, ) +STATE_NOT_COMPRESSED = b"\x01" +STATE_COMPRESSED = b"\x02" +STATE_CHUNK_SIZE = 1024 + def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable @@ -2218,6 +2223,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): error = "" try: payload = pickle.dumps((self._to_schema(), self)) + if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: + self._check_state_size(len(payload)) except HANDLED_PICKLE_ERRORS as og_pickle_error: error = ( f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " @@ -2236,12 +2243,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): error += f"Dill was also unable to pickle the state: {ex}" console.warn(error) + size = len(payload) if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: - self._check_state_size(len(payload)) + self._check_state_size(size) if not payload: raise StateSerializationError(error) + if environment.REFLEX_COMPRESS_STATE.get(): + if size > environment.REFLEX_COMPRESS_THRESHOLD.get(): + from blosc2 import compress + + payload = compress(payload) + prefix = STATE_COMPRESSED + else: + prefix = STATE_NOT_COMPRESSED + payload = prefix + payload + return payload @classmethod @@ -2264,14 +2282,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): StateSchemaMismatchError: If the state schema does not match the expected schema. """ if data is not None and fp is None: - (substate_schema, state) = pickle.loads(data) + if environment.REFLEX_COMPRESS_STATE.get(): + # get first byte to determine if compressed + is_compressed = data[:1] == STATE_COMPRESSED + # remove compression byte + data = data[1:] + if is_compressed: + from blosc2 import decompress + + data = decompress(data) + data = pickle.loads(data) # type: ignore elif fp is not None and data is None: - (substate_schema, state) = pickle.load(fp) + if environment.REFLEX_COMPRESS_STATE.get(): + # read first byte to determine if compressed + is_compressed = fp.read(1) == STATE_COMPRESSED + if is_compressed: + from blosc2 import SChunk + + schunk = SChunk(chunksize=STATE_CHUNK_SIZE) + + while chunk := fp.read(STATE_CHUNK_SIZE): + schunk.append_data(chunk) + + fp = BytesIO() + + for chunk_index in range(schunk.nchunks): + fp.write(schunk.decompress_chunk(chunk_index)) + + data = pickle.load(fp) else: raise ValueError("Only one of `data` or `fp` must be provided") - if substate_schema != state._to_schema(): + substate_schema, state = data # type: ignore + if substate_schema != state._to_schema(): # type: ignore raise StateSchemaMismatchError() - return state + return state # type: ignore class State(BaseState):