Merge 86342c79f3
into 8b2c7291d3
This commit is contained in:
commit
f90fb78b97
155
benchmarks/benchmark_pickle.py
Normal file
155
benchmarks/benchmark_pickle.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""Benchmarks for pickling and unpickling states."""
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
import uuid
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
from pytest_benchmark.fixture import BenchmarkFixture
|
||||
from redis import Redis
|
||||
|
||||
from reflex.state import State
|
||||
from reflex.utils.prerequisites import get_redis_sync
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Backing store shared by all SlowRedis instances, emulating a single
# external Redis server within the benchmark process.
SLOW_REDIS_MAP: dict[bytes, bytes] = {}


class SlowRedis:
    """Simulate a slow Redis client which uses a global dict and sleeps based on size."""

    @staticmethod
    def _simulate_latency(payload: bytes) -> None:
        """Sleep proportionally to the payload size to mimic network transfer.

        Args:
            payload: The bytes being transferred.
        """
        # ~1 MB/s simulated throughput plus a 50 ms fixed round-trip cost.
        time.sleep((len(payload) / 1e6) + 0.05)

    def set(self, key: bytes, value: bytes) -> None:
        """Set a key-value pair in the slow Redis client.

        Args:
            key: The key.
            value: The value.
        """
        SLOW_REDIS_MAP[key] = value
        self._simulate_latency(value)

    def get(self, key: bytes) -> bytes:
        """Get a value from the slow Redis client.

        Args:
            key: The key.

        Returns:
            The value.
        """
        value = SLOW_REDIS_MAP[key]
        self._simulate_latency(value)
        return value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "protocol",
    argvalues=[
        pickle.DEFAULT_PROTOCOL,
        pickle.HIGHEST_PROTOCOL,
    ],
    ids=[
        "pickle_default",
        "pickle_highest",
    ],
)
@pytest.mark.parametrize(
    "redis",
    [
        Redis,
        SlowRedis,
        None,
    ],
    ids=[
        "redis",
        "slow_redis",
        "no_redis",
    ],
)
@pytest.mark.parametrize(
    "should_compress", [True, False], ids=["compress", "no_compress"]
)
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
    request: pytest.FixtureRequest,
    benchmark: BenchmarkFixture,
    big_state: State,
    big_state_size: Tuple[int, str],
    protocol: int,
    redis: type[Redis] | type[SlowRedis] | None,
    should_compress: bool,
) -> None:
    """Benchmark pickling a big state.

    Args:
        request: The pytest fixture request object.
        benchmark: The benchmark fixture.
        big_state: The big state fixture.
        big_state_size: The (size, name) tuple describing the big state fixture.
        protocol: The pickle protocol.
        redis: The Redis client class to route payloads through, or None to
            pickle/unpickle in memory only.
        should_compress: Whether to compress the pickled state.
    """
    if should_compress:
        try:
            from blosc2 import compress, decompress
        except ImportError:
            pytest.skip("Blosc is not available.")

        def dump(obj: State) -> bytes:
            return compress(pickle.dumps(obj, protocol=protocol))  # pyright: ignore[reportReturnType]

        def load(data: bytes) -> State:
            return pickle.loads(decompress(data))  # pyright: ignore[reportAny,reportArgumentType]

    else:

        def dump(obj: State) -> bytes:
            return pickle.dumps(obj, protocol=protocol)

        def load(data: bytes) -> State:
            return pickle.loads(data)

    if redis:
        # `redis` is a class from the parametrize list, so compare by identity.
        if redis is Redis:
            redis_client = get_redis_sync()
            if redis_client is None:
                pytest.skip("Redis is not available.")
        else:
            redis_client = SlowRedis()

        key = str(uuid.uuid4()).encode()

        def run(obj: State) -> None:
            _ = redis_client.set(key, dump(obj))
            _ = load(redis_client.get(key))  # pyright: ignore[reportArgumentType]

    else:

        def run(obj: State) -> None:
            _ = load(dump(obj))

    # calculate size before benchmark to not affect it
    out = dump(big_state)
    size = len(out)
    log.warning(f"{protocol=}, {redis=}, {should_compress=}, {size=}")

    benchmark.extra_info["size"] = size
    benchmark.extra_info["redis"] = redis
    benchmark.extra_info["pickle_protocol"] = protocol
    redis_group = redis.__name__ if redis else "no_redis"
    benchmark.group = f"{redis_group}_{big_state_size[1]}"

    _ = benchmark(run, big_state)
|
@ -1,7 +1,11 @@
|
||||
"""Shared conftest for all benchmark tests."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from reflex.state import State
|
||||
from reflex.testing import AppHarness, AppHarnessProd
|
||||
|
||||
|
||||
@ -18,3 +22,43 @@ def app_harness_env(request):
|
||||
The AppHarness class to use for the test.
|
||||
"""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[(10, "SmallState"), (2000, "BigState")], ids=["small", "big"])
def big_state_size(request: pytest.FixtureRequest) -> Tuple[int, str]:
    """The size and display name of the big state.

    Args:
        request: The pytest fixture request object.

    Returns:
        A (size, name) tuple, e.g. ``(10, "SmallState")``.
    """
    # NOTE: the params are (int, str) tuples, so the return annotation is
    # Tuple[int, str], not int (consumers unpack `size, _ = big_state_size`).
    return request.param
|
||||
|
||||
|
||||
@pytest.fixture
def big_state(big_state_size: Tuple[int, str]) -> State:
    """A big state with a dictionary and a list of DataFrames.

    Args:
        big_state_size: The size of the big state.

    Returns:
        A big state instance.
    """
    size = big_state_size[0]

    class BigState(State):
        """A big state."""

        d: dict[str, int]
        d_repeated: dict[str, int]
        df: list[pd.DataFrame]

    int_map = {str(i): i for i in range(size)}
    int_map_copy = {str(i): i for i in range(size)}
    frames = [pd.DataFrame({"a": [i]}) for i in range(size)]

    return BigState(d=int_map, df=frames, d_repeated=int_map_copy)
|
||||
|
478
poetry.lock
generated
478
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -72,6 +72,7 @@ pytest-benchmark = ">=4.0.0,<6.0"
|
||||
playwright = ">=1.46.0"
|
||||
pytest-playwright = ">=0.5.1"
|
||||
pytest-codspeed = "^3.1.2"
|
||||
blosc2 = { version = ">=2.7.1", python = ">=3.11" }
|
||||
|
||||
[tool.poetry.scripts]
|
||||
reflex = "reflex.reflex:cli"
|
||||
@ -87,8 +88,37 @@ reportIncompatibleMethodOverride = false
|
||||
target-version = "py310"
|
||||
output-format = "concise"
|
||||
lint.isort.split-on-trailing-comma = false
|
||||
lint.select = ["ANN001","B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"]
|
||||
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF008", "RUF012", "TRY0"]
|
||||
lint.select = [
|
||||
"ANN001",
|
||||
"B",
|
||||
"C4",
|
||||
"D",
|
||||
"E",
|
||||
"ERA",
|
||||
"F",
|
||||
"FURB",
|
||||
"I",
|
||||
"N",
|
||||
"PERF",
|
||||
"PGH",
|
||||
"PTH",
|
||||
"RUF",
|
||||
"SIM",
|
||||
"T",
|
||||
"TRY",
|
||||
"W",
|
||||
]
|
||||
lint.ignore = [
|
||||
"B008",
|
||||
"D205",
|
||||
"E501",
|
||||
"F403",
|
||||
"SIM115",
|
||||
"RUF006",
|
||||
"RUF008",
|
||||
"RUF012",
|
||||
"TRY0",
|
||||
]
|
||||
lint.pydocstyle.convention = "google"
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
|
@ -571,6 +571,12 @@ class EnvironmentVariables:
|
||||
# Whether to use the turbopack bundler.
|
||||
REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
|
||||
|
||||
# Whether to compress the reflex state.
|
||||
REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)
|
||||
|
||||
# Threshold for the reflex state compression in bytes.
|
||||
REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)
|
||||
|
||||
|
||||
environment = EnvironmentVariables()
|
||||
|
||||
|
@ -16,6 +16,7 @@ import typing
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from hashlib import md5
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import FunctionType, MethodType
|
||||
from typing import (
|
||||
@ -145,6 +146,10 @@ HANDLED_PICKLE_ERRORS = (
|
||||
ValueError,
|
||||
)
|
||||
|
||||
STATE_NOT_COMPRESSED = b"\x01"
|
||||
STATE_COMPRESSED = b"\x02"
|
||||
STATE_CHUNK_SIZE = 1024
|
||||
|
||||
# For BaseState.get_var_value
|
||||
VAR_TYPE = TypeVar("VAR_TYPE")
|
||||
|
||||
@ -2200,12 +2205,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
error += f"Dill was also unable to pickle the state: {ex}"
|
||||
console.warn(error)
|
||||
|
||||
size = len(payload)
|
||||
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
|
||||
self._check_state_size(len(payload))
|
||||
self._check_state_size(size)
|
||||
|
||||
if not payload:
|
||||
raise StateSerializationError(error)
|
||||
|
||||
if environment.REFLEX_COMPRESS_STATE.get():
|
||||
if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
|
||||
from blosc2 import compress
|
||||
|
||||
payload = compress(payload, _ignore_multiple_size=True)
|
||||
prefix = STATE_COMPRESSED
|
||||
else:
|
||||
prefix = STATE_NOT_COMPRESSED
|
||||
payload = prefix + payload # pyright: ignore[reportOperatorIssue,reportUnknownVariableType]
|
||||
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
@ -2228,14 +2244,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
StateSchemaMismatchError: If the state schema does not match the expected schema.
|
||||
"""
|
||||
if data is not None and fp is None:
|
||||
(substate_schema, state) = pickle.loads(data)
|
||||
if environment.REFLEX_COMPRESS_STATE.get():
|
||||
# get first byte to determine if compressed
|
||||
is_compressed = data[:1] == STATE_COMPRESSED
|
||||
# remove compression byte
|
||||
data = data[1:]
|
||||
if is_compressed:
|
||||
from blosc2 import decompress
|
||||
|
||||
data = decompress(data) # pyright: ignore[reportAssignmentType]
|
||||
data = pickle.loads(data) # pyright: ignore[reportArgumentType]
|
||||
elif fp is not None and data is None:
|
||||
(substate_schema, state) = pickle.load(fp)
|
||||
if environment.REFLEX_COMPRESS_STATE.get():
|
||||
# read first byte to determine if compressed
|
||||
is_compressed = fp.read(1) == STATE_COMPRESSED
|
||||
if is_compressed:
|
||||
from blosc2 import SChunk
|
||||
|
||||
schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
|
||||
|
||||
while chunk := fp.read(STATE_CHUNK_SIZE):
|
||||
schunk.append_data(chunk)
|
||||
|
||||
fp = BytesIO()
|
||||
|
||||
for chunk_index in range(schunk.nchunks):
|
||||
fp.write(schunk.decompress_chunk(chunk_index)) # pyright: ignore[reportArgumentType,reportUnusedCallResult]
|
||||
|
||||
data = pickle.load(fp)
|
||||
else:
|
||||
raise ValueError("Only one of `data` or `fp` must be provided")
|
||||
if substate_schema != state._to_schema():
|
||||
substate_schema, state = data # pyright: ignore[reportUnknownVariableType,reportGeneralTypeIssues]
|
||||
if substate_schema != state._to_schema(): # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
|
||||
raise StateSchemaMismatchError()
|
||||
return state
|
||||
return state # pyright: ignore[reportUnknownVariableType,reportReturnType]
|
||||
|
||||
|
||||
T_STATE = TypeVar("T_STATE", bound=BaseState)
|
||||
|
Loading…
Reference in New Issue
Block a user