implement state compression

This commit is contained in:
Benedikt Bartscher 2024-11-25 00:07:25 +01:00
parent c9656eef6e
commit ec1a2d3a97
No known key found for this signature in database
2 changed files with 55 additions and 5 deletions

View File

@ -567,6 +567,12 @@ class EnvironmentVariables:
# The maximum size of the reflex state in kilobytes. # The maximum size of the reflex state in kilobytes.
REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000) REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
# Whether to compress the reflex state.
REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)
# Serialized state size in bytes above which the state is compressed (only checked when REFLEX_COMPRESS_STATE is enabled).
REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)
environment = EnvironmentVariables() environment = EnvironmentVariables()

View File

@ -17,6 +17,7 @@ import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from hashlib import md5 from hashlib import md5
from io import BytesIO
from pathlib import Path from pathlib import Path
from types import FunctionType, MethodType from types import FunctionType, MethodType
from typing import ( from typing import (
@ -143,6 +144,10 @@ HANDLED_PICKLE_ERRORS = (
ValueError, ValueError,
) )
# One-byte prefix prepended to a serialized state payload that was left uncompressed.
STATE_NOT_COMPRESSED = b"\x01"
# One-byte prefix prepended to a serialized state payload that was compressed.
STATE_COMPRESSED = b"\x02"
# Size in bytes of each read from the file object (and the SChunk chunksize)
# when loading a compressed state.
STATE_CHUNK_SIZE = 1024
def _no_chain_background_task( def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable state_cls: Type["BaseState"], name: str, fn: Callable
@ -2218,6 +2223,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
error = "" error = ""
try: try:
payload = pickle.dumps((self._to_schema(), self)) payload = pickle.dumps((self._to_schema(), self))
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload))
except HANDLED_PICKLE_ERRORS as og_pickle_error: except HANDLED_PICKLE_ERRORS as og_pickle_error:
error = ( error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
@ -2236,12 +2243,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
error += f"Dill was also unable to pickle the state: {ex}" error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error) console.warn(error)
size = len(payload)
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload)) self._check_state_size(size)
if not payload: if not payload:
raise StateSerializationError(error) raise StateSerializationError(error)
if environment.REFLEX_COMPRESS_STATE.get():
if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
from blosc2 import compress
payload = compress(payload)
prefix = STATE_COMPRESSED
else:
prefix = STATE_NOT_COMPRESSED
payload = prefix + payload
return payload return payload
@classmethod @classmethod
@ -2264,14 +2282,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
StateSchemaMismatchError: If the state schema does not match the expected schema. StateSchemaMismatchError: If the state schema does not match the expected schema.
""" """
if data is not None and fp is None: if data is not None and fp is None:
(substate_schema, state) = pickle.loads(data) if environment.REFLEX_COMPRESS_STATE.get():
# get first byte to determine if compressed
is_compressed = data[:1] == STATE_COMPRESSED
# remove compression byte
data = data[1:]
if is_compressed:
from blosc2 import decompress
data = decompress(data)
data = pickle.loads(data) # type: ignore
elif fp is not None and data is None: elif fp is not None and data is None:
(substate_schema, state) = pickle.load(fp) if environment.REFLEX_COMPRESS_STATE.get():
# read first byte to determine if compressed
is_compressed = fp.read(1) == STATE_COMPRESSED
if is_compressed:
from blosc2 import SChunk
schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
while chunk := fp.read(STATE_CHUNK_SIZE):
schunk.append_data(chunk)
fp = BytesIO()
for chunk_index in range(schunk.nchunks):
fp.write(schunk.decompress_chunk(chunk_index))
data = pickle.load(fp)
else: else:
raise ValueError("Only one of `data` or `fp` must be provided") raise ValueError("Only one of `data` or `fp` must be provided")
if substate_schema != state._to_schema(): substate_schema, state = data # type: ignore
if substate_schema != state._to_schema(): # type: ignore
raise StateSchemaMismatchError() raise StateSchemaMismatchError()
return state return state # type: ignore
class State(BaseState): class State(BaseState):