implement state compression

This commit is contained in:
Benedikt Bartscher 2024-11-25 00:07:25 +01:00
parent c9656eef6e
commit ec1a2d3a97
No known key found for this signature in database
2 changed files with 55 additions and 5 deletions

View File

@ -567,6 +567,12 @@ class EnvironmentVariables:
# The maximum size of the reflex state in kilobytes.
REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
# Whether to compress the reflex state.
REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)
# Threshold for the reflex state compression in bytes.
REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)
environment = EnvironmentVariables()

View File

@ -17,6 +17,7 @@ import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from hashlib import md5
from io import BytesIO
from pathlib import Path
from types import FunctionType, MethodType
from typing import (
@ -143,6 +144,10 @@ HANDLED_PICKLE_ERRORS = (
ValueError,
)
STATE_NOT_COMPRESSED = b"\x01"
STATE_COMPRESSED = b"\x02"
STATE_CHUNK_SIZE = 1024
def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable
@ -2218,6 +2223,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
error = ""
try:
payload = pickle.dumps((self._to_schema(), self))
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload))
except HANDLED_PICKLE_ERRORS as og_pickle_error:
error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
@ -2236,12 +2243,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error)
size = len(payload)
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload))
self._check_state_size(size)
if not payload:
raise StateSerializationError(error)
if environment.REFLEX_COMPRESS_STATE.get():
if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
from blosc2 import compress
payload = compress(payload)
prefix = STATE_COMPRESSED
else:
prefix = STATE_NOT_COMPRESSED
payload = prefix + payload
return payload
@classmethod
@ -2264,14 +2282,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
StateSchemaMismatchError: If the state schema does not match the expected schema.
"""
if data is not None and fp is None:
(substate_schema, state) = pickle.loads(data)
if environment.REFLEX_COMPRESS_STATE.get():
# get first byte to determine if compressed
is_compressed = data[:1] == STATE_COMPRESSED
# remove compression byte
data = data[1:]
if is_compressed:
from blosc2 import decompress
data = decompress(data)
data = pickle.loads(data) # type: ignore
elif fp is not None and data is None:
(substate_schema, state) = pickle.load(fp)
if environment.REFLEX_COMPRESS_STATE.get():
# read first byte to determine if compressed
is_compressed = fp.read(1) == STATE_COMPRESSED
if is_compressed:
from blosc2 import SChunk
schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
while chunk := fp.read(STATE_CHUNK_SIZE):
schunk.append_data(chunk)
fp = BytesIO()
for chunk_index in range(schunk.nchunks):
fp.write(schunk.decompress_chunk(chunk_index))
data = pickle.load(fp)
else:
raise ValueError("Only one of `data` or `fp` must be provided")
if substate_schema != state._to_schema():
substate_schema, state = data # type: ignore
if substate_schema != state._to_schema(): # type: ignore
raise StateSchemaMismatchError()
return state
return state # type: ignore
class State(BaseState):