# reflex/benchmarks/benchmark_pickle.py
# 2025-01-30 15:30:34 +01:00
#
# 156 lines
# 4.0 KiB
# Python

"""Benchmarks for pickling and unpickling states."""
import logging
import pickle
import time
import uuid
from typing import Tuple
import pytest
from pytest_benchmark.fixture import BenchmarkFixture
from redis import Redis
from reflex.state import State
from reflex.utils.prerequisites import get_redis_sync
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Global in-memory store that SlowRedis.set writes to and SlowRedis.get reads
# from; it is shared by every SlowRedis instance in the process.
SLOW_REDIS_MAP: dict[bytes, bytes] = {}
class SlowRedis:
    """Stand-in for a slow Redis client.

    Values live in the module-level ``SLOW_REDIS_MAP`` dict, and every call
    sleeps for a duration that grows with the size of the value involved.
    """

    def __init__(self):
        """Create the client; there is no per-instance state to set up."""

    @staticmethod
    def _pause_for(payload: bytes) -> None:
        """Sleep for 0.05 s plus one second per million bytes of ``payload``.

        Args:
            payload: The value whose length determines the delay.
        """
        time.sleep((len(payload) / 1e6) + 0.05)

    def set(self, key: bytes, value: bytes) -> None:
        """Store ``value`` under ``key``, then sleep according to its size.

        Args:
            key: The key.
            value: The value.
        """
        SLOW_REDIS_MAP[key] = value
        self._pause_for(value)

    def get(self, key: bytes) -> bytes:
        """Look up ``key``, sleep according to the value's size, and return it.

        Args:
            key: The key.

        Returns:
            The value.
        """
        stored = SLOW_REDIS_MAP[key]
        self._pause_for(stored)
        return stored
@pytest.mark.parametrize(
    "protocol",
    argvalues=[
        pickle.DEFAULT_PROTOCOL,
        pickle.HIGHEST_PROTOCOL,
    ],
    ids=[
        "pickle_default",
        "pickle_highest",
    ],
)
@pytest.mark.parametrize(
    "redis",
    [
        Redis,
        SlowRedis,
        None,
    ],
    ids=[
        "redis",
        "slow_redis",
        "no_redis",
    ],
)
@pytest.mark.parametrize(
    "should_compress", [True, False], ids=["compress", "no_compress"]
)
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
    request: pytest.FixtureRequest,
    benchmark: BenchmarkFixture,
    big_state: State,
    big_state_size: Tuple[int, str],
    protocol: int,
    redis: type[Redis] | type[SlowRedis] | None,
    should_compress: bool,
) -> None:
    """Benchmark pickling a big state.

    The benchmarked round-trip is dump -> (optional store/fetch) -> load.
    The test is skipped when compression is requested but ``blosc2`` cannot
    be imported, or when real Redis is requested but no client is available.

    Args:
        request: The pytest fixture request object.
        benchmark: The benchmark fixture.
        big_state: The big state fixture.
        big_state_size: The big state size fixture; its second item is used
            in the benchmark group name.
        protocol: The pickle protocol.
        redis: The client class to round-trip through (``Redis`` or
            ``SlowRedis``), or None to benchmark pickling alone.
        should_compress: Whether to compress the pickled state.
    """
    if should_compress:
        try:
            from blosc2 import compress, decompress
        except ImportError:
            pytest.skip("Blosc is not available.")

        def dump(obj: State) -> bytes:
            return compress(pickle.dumps(obj, protocol=protocol))  # pyright: ignore[reportReturnType]

        def load(data: bytes) -> State:
            return pickle.loads(decompress(data))  # pyright: ignore[reportAny,reportArgumentType]

    else:

        def dump(obj: State) -> bytes:
            return pickle.dumps(obj, protocol=protocol)

        def load(data: bytes) -> State:
            return pickle.loads(data)

    redis_client = None
    key = b""
    if redis is not None:
        # ``redis`` is a class, so compare by identity rather than equality.
        if redis is Redis:
            redis_client = get_redis_sync()
            if redis_client is None:
                pytest.skip("Redis is not available.")
        else:
            redis_client = SlowRedis()
        # A unique key per test run avoids clashing with other runs.
        key = str(uuid.uuid4()).encode()

        def run(obj: State) -> None:
            _ = redis_client.set(key, dump(obj))
            _ = load(redis_client.get(key))  # pyright: ignore[reportArgumentType]

    else:

        def run(obj: State) -> None:
            _ = load(dump(obj))

    # calculate size before benchmark to not affect it
    size = len(dump(big_state))
    log.warning(f"{protocol=}, {redis=}, {should_compress=}, {size=}")

    redis_group = redis.__name__ if redis is not None else "no_redis"
    benchmark.extra_info["size"] = size
    # Record the backend name: a class object is not JSON-serializable and
    # would not survive being written to a saved benchmark report.
    benchmark.extra_info["redis"] = redis_group
    benchmark.extra_info["pickle_protocol"] = protocol
    benchmark.group = f"{redis_group}_{big_state_size[1]}"

    try:
        _ = benchmark(run, big_state)
    finally:
        # Remove the stored payload so repeated runs do not pile up big
        # states in the Redis server or in the in-memory slow store.
        if isinstance(redis_client, SlowRedis):
            SLOW_REDIS_MAP.pop(key, None)
        elif redis_client is not None:
            _ = redis_client.delete(key)