implement state compression

parent c9656eef6e
commit ec1a2d3a97
@@ -567,6 +567,12 @@ class EnvironmentVariables:
     # The maximum size of the reflex state in kilobytes.
     REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
 
+    # Whether to compress the reflex state.
+    REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)
+
+    # Threshold for the reflex state compression in bytes.
+    REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)
+
 
 environment = EnvironmentVariables()
 
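As a usage note (not part of the commit): these are ordinary Reflex environment variables, so a minimal sketch for enabling the new behaviour, assuming Reflex reads them from the process environment at startup, might look like the following. The chosen values are illustrative.

import os

# Illustrative only: turn on state compression and only compress payloads
# larger than 4096 bytes, assuming these variables are read from os.environ.
os.environ["REFLEX_COMPRESS_STATE"] = "true"
os.environ["REFLEX_COMPRESS_THRESHOLD"] = "4096"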
@@ -17,6 +17,7 @@ import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from hashlib import md5
+from io import BytesIO
 from pathlib import Path
 from types import FunctionType, MethodType
 from typing import (
@@ -143,6 +144,10 @@ HANDLED_PICKLE_ERRORS = (
     ValueError,
 )
 
+STATE_NOT_COMPRESSED = b"\x01"
+STATE_COMPRESSED = b"\x02"
+STATE_CHUNK_SIZE = 1024
+
 
 def _no_chain_background_task(
     state_cls: Type["BaseState"], name: str, fn: Callable
@@ -2218,6 +2223,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         error = ""
         try:
             payload = pickle.dumps((self._to_schema(), self))
+            if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
+                self._check_state_size(len(payload))
         except HANDLED_PICKLE_ERRORS as og_pickle_error:
             error = (
                 f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
@@ -2236,12 +2243,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 error += f"Dill was also unable to pickle the state: {ex}"
             console.warn(error)
 
+        size = len(payload)
         if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
-            self._check_state_size(len(payload))
+            self._check_state_size(size)
 
         if not payload:
             raise StateSerializationError(error)
 
+        if environment.REFLEX_COMPRESS_STATE.get():
+            if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
+                from blosc2 import compress
+
+                payload = compress(payload)
+                prefix = STATE_COMPRESSED
+            else:
+                prefix = STATE_NOT_COMPRESSED
+            payload = prefix + payload
+
         return payload
 
     @classmethod
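The serialization change above prepends a one-byte marker so the reader can tell whether the pickled payload was compressed. A minimal, self-contained sketch of that framing follows; the pack/unpack helpers are illustrative names, not functions from the commit, and the threshold mirrors the REFLEX_COMPRESS_THRESHOLD default.

import pickle

import blosc2

STATE_NOT_COMPRESSED = b"\x01"
STATE_COMPRESSED = b"\x02"


def pack(obj: object, threshold: int = 1024) -> bytes:
    # Pickle the object; compress only when the payload exceeds the threshold.
    payload = pickle.dumps(obj)
    if len(payload) > threshold:
        return STATE_COMPRESSED + blosc2.compress(payload)
    return STATE_NOT_COMPRESSED + payload


def unpack(data: bytes) -> object:
    # Inspect the marker byte, strip it, decompress if needed, then unpickle.
    prefix, body = data[:1], data[1:]
    if prefix == STATE_COMPRESSED:
        body = blosc2.decompress(body)
    return pickle.loads(body)


assert unpack(pack(list(range(10_000)))) == list(range(10_000))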
@@ -2264,14 +2282,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             StateSchemaMismatchError: If the state schema does not match the expected schema.
         """
         if data is not None and fp is None:
-            (substate_schema, state) = pickle.loads(data)
+            if environment.REFLEX_COMPRESS_STATE.get():
+                # get first byte to determine if compressed
+                is_compressed = data[:1] == STATE_COMPRESSED
+                # remove compression byte
+                data = data[1:]
+                if is_compressed:
+                    from blosc2 import decompress
+
+                    data = decompress(data)
+            data = pickle.loads(data)  # type: ignore
         elif fp is not None and data is None:
-            (substate_schema, state) = pickle.load(fp)
+            if environment.REFLEX_COMPRESS_STATE.get():
+                # read first byte to determine if compressed
+                is_compressed = fp.read(1) == STATE_COMPRESSED
+                if is_compressed:
+                    from blosc2 import SChunk
+
+                    schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
+
+                    while chunk := fp.read(STATE_CHUNK_SIZE):
+                        schunk.append_data(chunk)
+
+                    fp = BytesIO()
+
+                    for chunk_index in range(schunk.nchunks):
+                        fp.write(schunk.decompress_chunk(chunk_index))
+
+            data = pickle.load(fp)
         else:
             raise ValueError("Only one of `data` or `fp` must be provided")
-        if substate_schema != state._to_schema():
+        substate_schema, state = data  # type: ignore
+        if substate_schema != state._to_schema():  # type: ignore
             raise StateSchemaMismatchError()
-        return state
+        return state  # type: ignore
 
 
 class State(BaseState):
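For the file-based branch, the commit buffers the compressed body through a blosc2 SChunk before unpickling. A simplified, non-streaming sketch of the same end result is shown below for orientation; load_state is an illustrative helper, not part of the commit, and it assumes the file contains exactly the marker byte plus the (possibly compressed) pickle produced by the serialization path above.

import pickle
from typing import BinaryIO

import blosc2

STATE_COMPRESSED = b"\x02"


def load_state(fp: BinaryIO) -> object:
    # Consume the one-byte marker, read the rest of the file in one pass,
    # decompress if the marker says so, then unpickle.
    is_compressed = fp.read(1) == STATE_COMPRESSED
    body = fp.read()
    if is_compressed:
        body = blosc2.decompress(body)
    return pickle.loads(body)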