This commit is contained in:
benedikt-bartscher 2025-02-08 20:01:04 +00:00 committed by GitHub
commit f90fb78b97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 509 additions and 260 deletions

View File

@ -0,0 +1,155 @@
"""Benchmarks for pickling and unpickling states."""
import logging
import pickle
import time
import uuid
from typing import Tuple
import pytest
from pytest_benchmark.fixture import BenchmarkFixture
from redis import Redis
from reflex.state import State
from reflex.utils.prerequisites import get_redis_sync
log = logging.getLogger(__name__)
SLOW_REDIS_MAP: dict[bytes, bytes] = {}
class SlowRedis:
"""Simulate a slow Redis client which uses a global dict and sleeps based on size."""
def __init__(self):
"""Initialize the slow Redis client."""
pass
def set(self, key: bytes, value: bytes) -> None:
"""Set a key-value pair in the slow Redis client.
Args:
key: The key.
value: The value.
"""
SLOW_REDIS_MAP[key] = value
size = len(value)
sleep_time = (size / 1e6) + 0.05
time.sleep(sleep_time)
def get(self, key: bytes) -> bytes:
"""Get a value from the slow Redis client.
Args:
key: The key.
Returns:
The value.
"""
value = SLOW_REDIS_MAP[key]
size = len(value)
sleep_time = (size / 1e6) + 0.05
time.sleep(sleep_time)
return value
@pytest.mark.parametrize(
"protocol",
argvalues=[
pickle.DEFAULT_PROTOCOL,
pickle.HIGHEST_PROTOCOL,
],
ids=[
"pickle_default",
"pickle_highest",
],
)
@pytest.mark.parametrize(
"redis",
[
Redis,
SlowRedis,
None,
],
ids=[
"redis",
"slow_redis",
"no_redis",
],
)
@pytest.mark.parametrize(
"should_compress", [True, False], ids=["compress", "no_compress"]
)
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
request: pytest.FixtureRequest,
benchmark: BenchmarkFixture,
big_state: State,
big_state_size: Tuple[int, str],
protocol: int,
redis: Redis | SlowRedis | None,
should_compress: bool,
) -> None:
"""Benchmark pickling a big state.
Args:
request: The pytest fixture request object.
benchmark: The benchmark fixture.
big_state: The big state fixture.
big_state_size: The big state size fixture.
protocol: The pickle protocol.
redis: Whether to use Redis.
should_compress: Whether to compress the pickled state.
"""
if should_compress:
try:
from blosc2 import compress, decompress
except ImportError:
pytest.skip("Blosc is not available.")
def dump(obj: State) -> bytes:
return compress(pickle.dumps(obj, protocol=protocol)) # pyright: ignore[reportReturnType]
def load(data: bytes) -> State:
return pickle.loads(decompress(data)) # pyright: ignore[reportAny,reportArgumentType]
else:
def dump(obj: State) -> bytes:
return pickle.dumps(obj, protocol=protocol)
def load(data: bytes) -> State:
return pickle.loads(data)
if redis:
if redis == Redis:
redis_client = get_redis_sync()
if redis_client is None:
pytest.skip("Redis is not available.")
else:
redis_client = SlowRedis()
key = str(uuid.uuid4()).encode()
def run(obj: State) -> None:
_ = redis_client.set(key, dump(obj))
_ = load(redis_client.get(key)) # pyright: ignore[reportArgumentType]
else:
def run(obj: State) -> None:
_ = load(dump(obj))
# calculate size before benchmark to not affect it
out = dump(big_state)
size = len(out)
log.warning(f"{protocol=}, {redis=}, {should_compress=}, {size=}")
benchmark.extra_info["size"] = size
benchmark.extra_info["redis"] = redis
benchmark.extra_info["pickle_protocol"] = protocol
redis_group = redis.__name__ if redis else "no_redis" # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnknownVariableType]
benchmark.group = f"{redis_group}_{big_state_size[1]}"
_ = benchmark(run, big_state)

View File

@ -1,7 +1,11 @@
"""Shared conftest for all benchmark tests.""" """Shared conftest for all benchmark tests."""
from typing import Tuple
import pandas as pd
import pytest import pytest
from reflex.state import State
from reflex.testing import AppHarness, AppHarnessProd from reflex.testing import AppHarness, AppHarnessProd
@ -18,3 +22,43 @@ def app_harness_env(request):
The AppHarness class to use for the test. The AppHarness class to use for the test.
""" """
return request.param return request.param
@pytest.fixture(params=[(10, "SmallState"), (2000, "BigState")], ids=["small", "big"])
def big_state_size(request: pytest.FixtureRequest) -> int:
"""The size of the DataFrame.
Args:
request: The pytest fixture request object.
Returns:
The size of the BigState
"""
return request.param
@pytest.fixture
def big_state(big_state_size: Tuple[int, str]) -> State:
"""A big state with a dictionary and a list of DataFrames.
Args:
big_state_size: The size of the big state.
Returns:
A big state instance.
"""
size, _ = big_state_size
class BigState(State):
"""A big state."""
d: dict[str, int]
d_repeated: dict[str, int]
df: list[pd.DataFrame]
d = {str(i): i for i in range(size)}
d_repeated = {str(i): i for i in range(size)}
df = [pd.DataFrame({"a": [i]}) for i in range(size)]
state = BigState(d=d, df=df, d_repeated=d_repeated)
return state

478
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -72,6 +72,7 @@ pytest-benchmark = ">=4.0.0,<6.0"
playwright = ">=1.46.0" playwright = ">=1.46.0"
pytest-playwright = ">=0.5.1" pytest-playwright = ">=0.5.1"
pytest-codspeed = "^3.1.2" pytest-codspeed = "^3.1.2"
blosc2 = { version = ">=2.7.1", python = ">=3.11" }
[tool.poetry.scripts] [tool.poetry.scripts]
reflex = "reflex.reflex:cli" reflex = "reflex.reflex:cli"
@ -87,8 +88,37 @@ reportIncompatibleMethodOverride = false
target-version = "py310" target-version = "py310"
output-format = "concise" output-format = "concise"
lint.isort.split-on-trailing-comma = false lint.isort.split-on-trailing-comma = false
lint.select = ["ANN001","B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"] lint.select = [
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF008", "RUF012", "TRY0"] "ANN001",
"B",
"C4",
"D",
"E",
"ERA",
"F",
"FURB",
"I",
"N",
"PERF",
"PGH",
"PTH",
"RUF",
"SIM",
"T",
"TRY",
"W",
]
lint.ignore = [
"B008",
"D205",
"E501",
"F403",
"SIM115",
"RUF006",
"RUF008",
"RUF012",
"TRY0",
]
lint.pydocstyle.convention = "google" lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]

View File

@ -571,6 +571,12 @@ class EnvironmentVariables:
# Whether to use the turbopack bundler. # Whether to use the turbopack bundler.
REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True) REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
# Whether to compress the reflex state.
REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)
# Threshold for the reflex state compression in bytes.
REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)
environment = EnvironmentVariables() environment = EnvironmentVariables()

View File

@ -16,6 +16,7 @@ import typing
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from hashlib import md5 from hashlib import md5
from io import BytesIO
from pathlib import Path from pathlib import Path
from types import FunctionType, MethodType from types import FunctionType, MethodType
from typing import ( from typing import (
@ -145,6 +146,10 @@ HANDLED_PICKLE_ERRORS = (
ValueError, ValueError,
) )
STATE_NOT_COMPRESSED = b"\x01"
STATE_COMPRESSED = b"\x02"
STATE_CHUNK_SIZE = 1024
# For BaseState.get_var_value # For BaseState.get_var_value
VAR_TYPE = TypeVar("VAR_TYPE") VAR_TYPE = TypeVar("VAR_TYPE")
@ -2200,12 +2205,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
error += f"Dill was also unable to pickle the state: {ex}" error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error) console.warn(error)
size = len(payload)
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
self._check_state_size(len(payload)) self._check_state_size(size)
if not payload: if not payload:
raise StateSerializationError(error) raise StateSerializationError(error)
if environment.REFLEX_COMPRESS_STATE.get():
if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
from blosc2 import compress
payload = compress(payload, _ignore_multiple_size=True)
prefix = STATE_COMPRESSED
else:
prefix = STATE_NOT_COMPRESSED
payload = prefix + payload # pyright: ignore[reportOperatorIssue,reportUnknownVariableType]
return payload return payload
@classmethod @classmethod
@ -2228,14 +2244,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
StateSchemaMismatchError: If the state schema does not match the expected schema. StateSchemaMismatchError: If the state schema does not match the expected schema.
""" """
if data is not None and fp is None: if data is not None and fp is None:
(substate_schema, state) = pickle.loads(data) if environment.REFLEX_COMPRESS_STATE.get():
# get first byte to determine if compressed
is_compressed = data[:1] == STATE_COMPRESSED
# remove compression byte
data = data[1:]
if is_compressed:
from blosc2 import decompress
data = decompress(data) # pyright: ignore[reportAssignmentType]
data = pickle.loads(data) # pyright: ignore[reportArgumentType]
elif fp is not None and data is None: elif fp is not None and data is None:
(substate_schema, state) = pickle.load(fp) if environment.REFLEX_COMPRESS_STATE.get():
# read first byte to determine if compressed
is_compressed = fp.read(1) == STATE_COMPRESSED
if is_compressed:
from blosc2 import SChunk
schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
while chunk := fp.read(STATE_CHUNK_SIZE):
schunk.append_data(chunk)
fp = BytesIO()
for chunk_index in range(schunk.nchunks):
fp.write(schunk.decompress_chunk(chunk_index)) # pyright: ignore[reportArgumentType,reportUnusedCallResult]
data = pickle.load(fp)
else: else:
raise ValueError("Only one of `data` or `fp` must be provided") raise ValueError("Only one of `data` or `fp` must be provided")
if substate_schema != state._to_schema(): substate_schema, state = data # pyright: ignore[reportUnknownVariableType,reportGeneralTypeIssues]
if substate_schema != state._to_schema(): # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
raise StateSchemaMismatchError() raise StateSchemaMismatchError()
return state return state # pyright: ignore[reportUnknownVariableType,reportReturnType]
T_STATE = TypeVar("T_STATE", bound=BaseState) T_STATE = TypeVar("T_STATE", bound=BaseState)