benedikt-bartscher 2025-02-08 20:01:04 +00:00 committed by GitHub
commit f90fb78b97
6 changed files with 509 additions and 260 deletions

View File

@@ -0,0 +1,155 @@
"""Benchmarks for pickling and unpickling states."""

import logging
import pickle
import time
import uuid
from typing import Tuple

import pytest
from pytest_benchmark.fixture import BenchmarkFixture
from redis import Redis

from reflex.state import State
from reflex.utils.prerequisites import get_redis_sync

log = logging.getLogger(__name__)

SLOW_REDIS_MAP: dict[bytes, bytes] = {}


class SlowRedis:
    """Simulate a slow Redis client which uses a global dict and sleeps based on size."""

    def __init__(self):
        """Initialize the slow Redis client."""
        pass

    def set(self, key: bytes, value: bytes) -> None:
        """Set a key-value pair in the slow Redis client.

        Args:
            key: The key.
            value: The value.
        """
        SLOW_REDIS_MAP[key] = value
        size = len(value)
        # Simulate ~1 MB/s of throughput plus a 50 ms base latency per operation.
        sleep_time = (size / 1e6) + 0.05
        time.sleep(sleep_time)

    def get(self, key: bytes) -> bytes:
        """Get a value from the slow Redis client.

        Args:
            key: The key.

        Returns:
            The value.
        """
        value = SLOW_REDIS_MAP[key]
        size = len(value)
        sleep_time = (size / 1e6) + 0.05
        time.sleep(sleep_time)
        return value
@pytest.mark.parametrize(
    "protocol",
    argvalues=[
        pickle.DEFAULT_PROTOCOL,
        pickle.HIGHEST_PROTOCOL,
    ],
    ids=[
        "pickle_default",
        "pickle_highest",
    ],
)
@pytest.mark.parametrize(
    "redis",
    [
        Redis,
        SlowRedis,
        None,
    ],
    ids=[
        "redis",
        "slow_redis",
        "no_redis",
    ],
)
@pytest.mark.parametrize(
    "should_compress", [True, False], ids=["compress", "no_compress"]
)
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
    request: pytest.FixtureRequest,
    benchmark: BenchmarkFixture,
    big_state: State,
    big_state_size: Tuple[int, str],
    protocol: int,
    redis: Redis | SlowRedis | None,
    should_compress: bool,
) -> None:
    """Benchmark pickling a big state.

    Args:
        request: The pytest fixture request object.
        benchmark: The benchmark fixture.
        big_state: The big state fixture.
        big_state_size: The big state size fixture.
        protocol: The pickle protocol.
        redis: The Redis client class to use, or None to skip Redis.
        should_compress: Whether to compress the pickled state.
    """
    if should_compress:
        try:
            from blosc2 import compress, decompress
        except ImportError:
            pytest.skip("Blosc is not available.")

        def dump(obj: State) -> bytes:
            return compress(pickle.dumps(obj, protocol=protocol))  # pyright: ignore[reportReturnType]

        def load(data: bytes) -> State:
            return pickle.loads(decompress(data))  # pyright: ignore[reportAny,reportArgumentType]
    else:

        def dump(obj: State) -> bytes:
            return pickle.dumps(obj, protocol=protocol)

        def load(data: bytes) -> State:
            return pickle.loads(data)

    if redis:
        if redis == Redis:
            redis_client = get_redis_sync()
            if redis_client is None:
                pytest.skip("Redis is not available.")
        else:
            redis_client = SlowRedis()
        key = str(uuid.uuid4()).encode()

        def run(obj: State) -> None:
            _ = redis_client.set(key, dump(obj))
            _ = load(redis_client.get(key))  # pyright: ignore[reportArgumentType]
    else:

        def run(obj: State) -> None:
            _ = load(dump(obj))

    # Calculate the size before benchmarking so it does not affect the measurement.
    out = dump(big_state)
    size = len(out)
    log.warning(f"{protocol=}, {redis=}, {should_compress=}, {size=}")
    benchmark.extra_info["size"] = size
    benchmark.extra_info["redis"] = redis
    benchmark.extra_info["pickle_protocol"] = protocol
    redis_group = redis.__name__ if redis else "no_redis"  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
    benchmark.group = f"{redis_group}_{big_state_size[1]}"
    _ = benchmark(run, big_state)
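
As a sanity check on the numbers this benchmark produces, the SlowRedis model above is easy to reason about by hand. A short illustrative calculation (the payload size is hypothetical, not from the commit):

# Each SlowRedis set()/get() sleeps (size / 1e6) + 0.05 seconds,
# i.e. ~1 MB/s of simulated throughput plus a 50 ms floor per operation.
payload_size = 2_000_000  # hypothetical 2 MB pickled state
per_op = payload_size / 1e6 + 0.05  # 2.05 s
round_trip = 2 * per_op  # one set() plus one get()
print(f"simulated round trip: {round_trip:.2f}s")  # 4.10s

This is why compression can win in the slow_redis groups even though it costs CPU time: shrinking the payload directly shortens the simulated transfer time.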

benchmarks/conftest.py View File

@@ -1,7 +1,11 @@
"""Shared conftest for all benchmark tests."""

from typing import Tuple

import pandas as pd
import pytest

from reflex.state import State
from reflex.testing import AppHarness, AppHarnessProd
@@ -18,3 +22,43 @@ def app_harness_env(request):
        The AppHarness class to use for the test.
    """
    return request.param
@pytest.fixture(params=[(10, "SmallState"), (2000, "BigState")], ids=["small", "big"])
def big_state_size(request: pytest.FixtureRequest) -> Tuple[int, str]:
    """The size and name of the big state.

    Args:
        request: The pytest fixture request object.

    Returns:
        A (size, name) tuple for the BigState.
    """
    return request.param
@pytest.fixture
def big_state(big_state_size: Tuple[int, str]) -> State:
    """A big state with two dictionaries and a list of DataFrames.

    Args:
        big_state_size: The size of the big state.

    Returns:
        A big state instance.
    """
    size, _ = big_state_size

    class BigState(State):
        """A big state."""

        d: dict[str, int]
        d_repeated: dict[str, int]
        df: list[pd.DataFrame]

    d = {str(i): i for i in range(size)}
    d_repeated = {str(i): i for i in range(size)}
    df = [pd.DataFrame({"a": [i]}) for i in range(size)]
    state = BigState(d=d, df=df, d_repeated=d_repeated)
    return state
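
For reference, a test consuming these fixtures receives the (size, name) tuple and the prepared state together. A minimal hypothetical example (not part of the commit):

def test_big_state_shapes(big_state: State, big_state_size: Tuple[int, str]) -> None:
    """Hypothetical example: verify the fixture's dimensions."""
    size, _name = big_state_size  # e.g. (2000, "BigState")
    assert len(big_state.d) == size
    assert len(big_state.df) == size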

poetry.lock generated

File diff suppressed because it is too large.

pyproject.toml View File

@@ -72,6 +72,7 @@ pytest-benchmark = ">=4.0.0,<6.0"
playwright = ">=1.46.0"
pytest-playwright = ">=0.5.1"
pytest-codspeed = "^3.1.2"
blosc2 = { version = ">=2.7.1", python = ">=3.11" }
[tool.poetry.scripts]
reflex = "reflex.reflex:cli"
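
Because blosc2 is constrained to Python >= 3.11 and lives in the dev dependencies, consumers guard the import rather than assume it. A sketch of the pattern (the guard name is illustrative, not from the commit):

try:
    import blosc2  # optional dev dependency, Python >= 3.11 only
except ImportError:
    blosc2 = None

HAS_BLOSC2 = blosc2 is not None  # hypothetical guard name

The benchmark above does this via pytest.skip, and the state.py hunks below defer `from blosc2 import ...` into branches that only run when compression is enabled.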
@@ -87,8 +88,37 @@ reportIncompatibleMethodOverride = false
target-version = "py310"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
-lint.select = ["ANN001","B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"]
-lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF008", "RUF012", "TRY0"]
+lint.select = [
+    "ANN001",
+    "B",
+    "C4",
+    "D",
+    "E",
+    "ERA",
+    "F",
+    "FURB",
+    "I",
+    "N",
+    "PERF",
+    "PGH",
+    "PTH",
+    "RUF",
+    "SIM",
+    "T",
+    "TRY",
+    "W",
+]
+lint.ignore = [
+    "B008",
+    "D205",
+    "E501",
+    "F403",
+    "SIM115",
+    "RUF006",
+    "RUF008",
+    "RUF012",
+    "TRY0",
+]
lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores]

reflex/config.py View File

@@ -571,6 +571,12 @@ class EnvironmentVariables:
    # Whether to use the turbopack bundler.
    REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)

    # Whether to compress the reflex state.
    REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)

    # Threshold for the reflex state compression in bytes.
    REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)


environment = EnvironmentVariables()
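
A hedged sketch of enabling the feature (values are illustrative; the defaults are False and 1024 bytes per the diff above):

import os

# Compress pickled state payloads larger than 4 KiB (hypothetical threshold).
os.environ["REFLEX_COMPRESS_STATE"] = "true"
os.environ["REFLEX_COMPRESS_THRESHOLD"] = "4096"

Note that writer and reader must agree on REFLEX_COMPRESS_STATE: as the state.py hunks below show, the marker byte is only written, and only checked, when the flag is set.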

reflex/state.py View File

@@ -16,6 +16,7 @@
import typing
import uuid
from abc import ABC, abstractmethod
from hashlib import md5
from io import BytesIO
from pathlib import Path
from types import FunctionType, MethodType
from typing import (
@@ -145,6 +146,10 @@
HANDLED_PICKLE_ERRORS = (
    ValueError,
)

STATE_NOT_COMPRESSED = b"\x01"
STATE_COMPRESSED = b"\x02"
STATE_CHUNK_SIZE = 1024

# For BaseState.get_var_value
VAR_TYPE = TypeVar("VAR_TYPE")
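
Taken together with the _serialize hunk below, these constants define a simple framing: when REFLEX_COMPRESS_STATE is on, every payload is a one-byte marker followed by either a raw or a blosc2-compressed pickle. A hypothetical inspector for that layout (not part of the commit):

def describe_payload(payload: bytes) -> str:
    """Hypothetical helper: classify a serialized state payload."""
    prefix, body = payload[:1], payload[1:]
    if prefix == STATE_COMPRESSED:  # b"\x02"
        return f"compressed, {len(body)} bytes on the wire"
    if prefix == STATE_NOT_COMPRESSED:  # b"\x01"
        return f"uncompressed, {len(body)} bytes"
    return "unprefixed: written with REFLEX_COMPRESS_STATE disabled"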
@@ -2200,12 +2205,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                error += f"Dill was also unable to pickle the state: {ex}"
            console.warn(error)

+        size = len(payload)
        if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
-            self._check_state_size(len(payload))
+            self._check_state_size(size)

        if not payload:
            raise StateSerializationError(error)

+        if environment.REFLEX_COMPRESS_STATE.get():
+            if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
+                from blosc2 import compress
+
+                payload = compress(payload, _ignore_multiple_size=True)
+                prefix = STATE_COMPRESSED
+            else:
+                prefix = STATE_NOT_COMPRESSED
+            payload = prefix + payload  # pyright: ignore[reportOperatorIssue,reportUnknownVariableType]
+
        return payload
@classmethod
@@ -2228,14 +2244,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
            StateSchemaMismatchError: If the state schema does not match the expected schema.
        """
        if data is not None and fp is None:
-            (substate_schema, state) = pickle.loads(data)
+            if environment.REFLEX_COMPRESS_STATE.get():
+                # Get the first byte to determine if the payload is compressed.
+                is_compressed = data[:1] == STATE_COMPRESSED
+                # Remove the compression marker byte.
+                data = data[1:]
+                if is_compressed:
+                    from blosc2 import decompress
+
+                    data = decompress(data)  # pyright: ignore[reportAssignmentType]
+            data = pickle.loads(data)  # pyright: ignore[reportArgumentType]
        elif fp is not None and data is None:
-            (substate_schema, state) = pickle.load(fp)
+            if environment.REFLEX_COMPRESS_STATE.get():
+                # Read the first byte to determine if the stream is compressed.
+                is_compressed = fp.read(1) == STATE_COMPRESSED
+                if is_compressed:
+                    from blosc2 import SChunk
+
+                    schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
+                    while chunk := fp.read(STATE_CHUNK_SIZE):
+                        schunk.append_data(chunk)
+                    fp = BytesIO()
+                    for chunk_index in range(schunk.nchunks):
+                        fp.write(schunk.decompress_chunk(chunk_index))  # pyright: ignore[reportArgumentType,reportUnusedCallResult]
+                    # Rewind so pickle.load reads the decompressed stream from the start.
+                    _ = fp.seek(0)
+            data = pickle.load(fp)
        else:
            raise ValueError("Only one of `data` or `fp` must be provided")
-        if substate_schema != state._to_schema():
+        substate_schema, state = data  # pyright: ignore[reportUnknownVariableType,reportGeneralTypeIssues]
+        if substate_schema != state._to_schema():  # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
            raise StateSchemaMismatchError()
-        return state
+        return state  # pyright: ignore[reportUnknownVariableType,reportReturnType]
T_STATE = TypeVar("T_STATE", bound=BaseState)
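
Mirroring the `data` branch of _deserialize above, a minimal sketch of decoding a prefixed payload by hand, assuming blosc2 is installed and compression was enabled at write time (helper name is hypothetical):

import pickle

def manual_deserialize(payload: bytes):
    """Hypothetical mirror of the in-memory branch of _deserialize."""
    body = payload[1:]
    if payload[:1] == STATE_COMPRESSED:  # b"\x02"
        from blosc2 import decompress

        body = decompress(body)
    # The pickle holds a (substate_schema, state) tuple, as the hunk above unpacks.
    substate_schema, state = pickle.loads(body)
    return substate_schema, state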