Merge 86342c79f3
into 8b2c7291d3
This commit is contained in:
commit
f90fb78b97
155
benchmarks/benchmark_pickle.py
Normal file
155
benchmarks/benchmark_pickle.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
"""Benchmarks for pickling and unpickling states."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_benchmark.fixture import BenchmarkFixture
|
||||||
|
from redis import Redis
|
||||||
|
|
||||||
|
from reflex.state import State
|
||||||
|
from reflex.utils.prerequisites import get_redis_sync
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SLOW_REDIS_MAP: dict[bytes, bytes] = {}


class SlowRedis:
    """Simulate a slow Redis client which uses a global dict and sleeps based on size.

    Values live in the module-level ``SLOW_REDIS_MAP``, so every instance reads
    and writes the same store. Each operation sleeps for ``base_delay`` seconds
    plus ``len(value) / bytes_per_second`` seconds.
    """

    def __init__(
        self, *, base_delay: float = 0.05, bytes_per_second: float = 1e6
    ) -> None:
        """Initialize the slow Redis client.

        Args:
            base_delay: Fixed latency in seconds added to every operation.
            bytes_per_second: Simulated throughput used to scale the delay with the payload size.

        Raises:
            ValueError: If base_delay is negative or bytes_per_second is not positive.
        """
        # Fail at construction time instead of inside the first timed operation.
        if base_delay < 0:
            raise ValueError("base_delay must not be negative")
        if bytes_per_second <= 0:
            raise ValueError("bytes_per_second must be positive")
        self.base_delay = base_delay
        self.bytes_per_second = bytes_per_second

    def _simulate_latency(self, size: int) -> None:
        """Sleep for the simulated duration of transferring a payload.

        Args:
            size: The payload size in bytes.
        """
        time.sleep((size / self.bytes_per_second) + self.base_delay)

    def set(self, key: bytes, value: bytes) -> None:
        """Set a key-value pair in the slow Redis client.

        Args:
            key: The key.
            value: The value.
        """
        SLOW_REDIS_MAP[key] = value
        self._simulate_latency(len(value))

    def get(self, key: bytes) -> bytes:
        """Get a value from the slow Redis client.

        Args:
            key: The key.

        Returns:
            The value.

        Raises:
            KeyError: If the key was never set.
        """
        value = SLOW_REDIS_MAP[key]
        self._simulate_latency(len(value))
        return value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "protocol",
    argvalues=[
        pickle.DEFAULT_PROTOCOL,
        pickle.HIGHEST_PROTOCOL,
    ],
    ids=[
        "pickle_default",
        "pickle_highest",
    ],
)
@pytest.mark.parametrize(
    "redis",
    [
        Redis,
        SlowRedis,
        None,
    ],
    ids=[
        "redis",
        "slow_redis",
        "no_redis",
    ],
)
@pytest.mark.parametrize(
    "should_compress", [True, False], ids=["compress", "no_compress"]
)
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
    request: pytest.FixtureRequest,
    benchmark: BenchmarkFixture,
    big_state: State,
    big_state_size: Tuple[int, str],
    protocol: int,
    redis: type[Redis] | type[SlowRedis] | None,
    should_compress: bool,
) -> None:
    """Benchmark pickling a big state.

    One benchmark round dumps the state, optionally stores it in and reads it
    back from the selected Redis client, and loads it again.

    Args:
        request: The pytest fixture request object.
        benchmark: The benchmark fixture.
        big_state: The big state fixture.
        big_state_size: The big state size fixture.
        protocol: The pickle protocol.
        redis: The Redis client class to round-trip through, or None to skip Redis.
        should_compress: Whether to compress the pickled state.
    """
    if should_compress:
        try:
            from blosc2 import compress, decompress
        except ImportError:
            pytest.skip("Blosc is not available.")

        def dump(obj: State) -> bytes:
            return compress(pickle.dumps(obj, protocol=protocol))  # pyright: ignore[reportReturnType]

        def load(data: bytes) -> State:
            return pickle.loads(decompress(data))  # pyright: ignore[reportAny,reportArgumentType]

    else:

        def dump(obj: State) -> bytes:
            return pickle.dumps(obj, protocol=protocol)

        def load(data: bytes) -> State:
            return pickle.loads(data)

    if redis is not None:
        # The parameter is a class, so compare by identity rather than equality.
        if redis is Redis:
            redis_client = get_redis_sync()
            if redis_client is None:
                pytest.skip("Redis is not available.")
        else:
            redis_client = SlowRedis()

        key = str(uuid.uuid4()).encode()

        def run(obj: State) -> None:
            _ = redis_client.set(key, dump(obj))
            _ = load(redis_client.get(key))  # pyright: ignore[reportArgumentType]

        def cleanup() -> None:
            # Remove the stored payload so it does not outlive this test case.
            if isinstance(redis_client, SlowRedis):
                _ = SLOW_REDIS_MAP.pop(key, None)
            else:
                _ = redis_client.delete(key)

    else:

        def run(obj: State) -> None:
            _ = load(dump(obj))

        def cleanup() -> None:
            # Nothing was stored outside the benchmark, so nothing to remove.
            return None

    # calculate size before benchmark to not affect it
    out = dump(big_state)
    size = len(out)
    log.warning(f"{protocol=}, {redis=}, {should_compress=}, {size=}")

    redis_group = redis.__name__ if redis is not None else "no_redis"

    benchmark.extra_info["size"] = size
    # Record the class name, not the class object: presumably extra_info ends up
    # in the saved benchmark report, where a plain string is the safer value.
    benchmark.extra_info["redis"] = redis_group
    benchmark.extra_info["pickle_protocol"] = protocol
    # Compressed and uncompressed runs share a group; keep the flag with the result.
    benchmark.extra_info["compressed"] = should_compress
    benchmark.group = f"{redis_group}_{big_state_size[1]}"

    try:
        _ = benchmark(run, big_state)
    finally:
        cleanup()
|
@ -1,7 +1,11 @@
|
|||||||
"""Shared conftest for all benchmark tests."""
|
"""Shared conftest for all benchmark tests."""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from reflex.state import State
|
||||||
from reflex.testing import AppHarness, AppHarnessProd
|
from reflex.testing import AppHarness, AppHarnessProd
|
||||||
|
|
||||||
|
|
||||||
@ -18,3 +22,43 @@ def app_harness_env(request):
|
|||||||
The AppHarness class to use for the test.
|
The AppHarness class to use for the test.
|
||||||
"""
|
"""
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[(10, "SmallState"), (2000, "BigState")], ids=["small", "big"])
def big_state_size(request: pytest.FixtureRequest) -> Tuple[int, str]:
    """The size and label of the state to benchmark.

    Args:
        request: The pytest fixture request object.

    Returns:
        A tuple of the number of items to generate and a name labelling that size.
    """
    return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def big_state(big_state_size: Tuple[int, str]) -> State:
    """Build a state holding two dictionaries and a list of DataFrames.

    Args:
        big_state_size: The size of the big state.

    Returns:
        A big state instance.
    """
    item_count = big_state_size[0]

    class BigState(State):
        """A big state."""

        d: dict[str, int]
        d_repeated: dict[str, int]
        df: list[pd.DataFrame]

    # Both dictionaries are built by their own comprehension (not copied from
    # one another), so they hold equal contents in separate objects.
    numbers = {str(index): index for index in range(item_count)}
    numbers_again = {str(index): index for index in range(item_count)}
    frames = [pd.DataFrame({"a": [index]}) for index in range(item_count)]

    return BigState(d=numbers, df=frames, d_repeated=numbers_again)
|
||||||
|
478
poetry.lock
generated
478
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -72,6 +72,7 @@ pytest-benchmark = ">=4.0.0,<6.0"
|
|||||||
playwright = ">=1.46.0"
|
playwright = ">=1.46.0"
|
||||||
pytest-playwright = ">=0.5.1"
|
pytest-playwright = ">=0.5.1"
|
||||||
pytest-codspeed = "^3.1.2"
|
pytest-codspeed = "^3.1.2"
|
||||||
|
blosc2 = { version = ">=2.7.1", python = ">=3.11" }
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
reflex = "reflex.reflex:cli"
|
reflex = "reflex.reflex:cli"
|
||||||
@ -87,8 +88,37 @@ reportIncompatibleMethodOverride = false
|
|||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
output-format = "concise"
|
output-format = "concise"
|
||||||
lint.isort.split-on-trailing-comma = false
|
lint.isort.split-on-trailing-comma = false
|
||||||
lint.select = ["ANN001","B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"]
|
lint.select = [
|
||||||
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF008", "RUF012", "TRY0"]
|
"ANN001",
|
||||||
|
"B",
|
||||||
|
"C4",
|
||||||
|
"D",
|
||||||
|
"E",
|
||||||
|
"ERA",
|
||||||
|
"F",
|
||||||
|
"FURB",
|
||||||
|
"I",
|
||||||
|
"N",
|
||||||
|
"PERF",
|
||||||
|
"PGH",
|
||||||
|
"PTH",
|
||||||
|
"RUF",
|
||||||
|
"SIM",
|
||||||
|
"T",
|
||||||
|
"TRY",
|
||||||
|
"W",
|
||||||
|
]
|
||||||
|
lint.ignore = [
|
||||||
|
"B008",
|
||||||
|
"D205",
|
||||||
|
"E501",
|
||||||
|
"F403",
|
||||||
|
"SIM115",
|
||||||
|
"RUF006",
|
||||||
|
"RUF008",
|
||||||
|
"RUF012",
|
||||||
|
"TRY0",
|
||||||
|
]
|
||||||
lint.pydocstyle.convention = "google"
|
lint.pydocstyle.convention = "google"
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
@ -571,6 +571,12 @@ class EnvironmentVariables:
|
|||||||
# Whether to use the turbopack bundler.
|
# Whether to use the turbopack bundler.
|
||||||
REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
|
REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
|
||||||
|
|
||||||
|
# Whether to compress the reflex state.
|
||||||
|
REFLEX_COMPRESS_STATE: EnvVar[bool] = env_var(False)
|
||||||
|
|
||||||
|
# Threshold for the reflex state compression in bytes.
|
||||||
|
REFLEX_COMPRESS_THRESHOLD: EnvVar[int] = env_var(1024)
|
||||||
|
|
||||||
|
|
||||||
environment = EnvironmentVariables()
|
environment = EnvironmentVariables()
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ import typing
|
|||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import FunctionType, MethodType
|
from types import FunctionType, MethodType
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -145,6 +146,10 @@ HANDLED_PICKLE_ERRORS = (
|
|||||||
ValueError,
|
ValueError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
STATE_NOT_COMPRESSED = b"\x01"
|
||||||
|
STATE_COMPRESSED = b"\x02"
|
||||||
|
STATE_CHUNK_SIZE = 1024
|
||||||
|
|
||||||
# For BaseState.get_var_value
|
# For BaseState.get_var_value
|
||||||
VAR_TYPE = TypeVar("VAR_TYPE")
|
VAR_TYPE = TypeVar("VAR_TYPE")
|
||||||
|
|
||||||
@ -2200,12 +2205,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
error += f"Dill was also unable to pickle the state: {ex}"
|
error += f"Dill was also unable to pickle the state: {ex}"
|
||||||
console.warn(error)
|
console.warn(error)
|
||||||
|
|
||||||
|
size = len(payload)
|
||||||
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
|
if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF:
|
||||||
self._check_state_size(len(payload))
|
self._check_state_size(size)
|
||||||
|
|
||||||
if not payload:
|
if not payload:
|
||||||
raise StateSerializationError(error)
|
raise StateSerializationError(error)
|
||||||
|
|
||||||
|
if environment.REFLEX_COMPRESS_STATE.get():
|
||||||
|
if size > environment.REFLEX_COMPRESS_THRESHOLD.get():
|
||||||
|
from blosc2 import compress
|
||||||
|
|
||||||
|
payload = compress(payload, _ignore_multiple_size=True)
|
||||||
|
prefix = STATE_COMPRESSED
|
||||||
|
else:
|
||||||
|
prefix = STATE_NOT_COMPRESSED
|
||||||
|
payload = prefix + payload # pyright: ignore[reportOperatorIssue,reportUnknownVariableType]
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -2228,14 +2244,40 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
StateSchemaMismatchError: If the state schema does not match the expected schema.
|
StateSchemaMismatchError: If the state schema does not match the expected schema.
|
||||||
"""
|
"""
|
||||||
if data is not None and fp is None:
|
if data is not None and fp is None:
|
||||||
(substate_schema, state) = pickle.loads(data)
|
if environment.REFLEX_COMPRESS_STATE.get():
|
||||||
|
# get first byte to determine if compressed
|
||||||
|
is_compressed = data[:1] == STATE_COMPRESSED
|
||||||
|
# remove compression byte
|
||||||
|
data = data[1:]
|
||||||
|
if is_compressed:
|
||||||
|
from blosc2 import decompress
|
||||||
|
|
||||||
|
data = decompress(data) # pyright: ignore[reportAssignmentType]
|
||||||
|
data = pickle.loads(data) # pyright: ignore[reportArgumentType]
|
||||||
elif fp is not None and data is None:
|
elif fp is not None and data is None:
|
||||||
(substate_schema, state) = pickle.load(fp)
|
if environment.REFLEX_COMPRESS_STATE.get():
|
||||||
|
# read first byte to determine if compressed
|
||||||
|
is_compressed = fp.read(1) == STATE_COMPRESSED
|
||||||
|
if is_compressed:
|
||||||
|
from blosc2 import SChunk
|
||||||
|
|
||||||
|
schunk = SChunk(chunksize=STATE_CHUNK_SIZE)
|
||||||
|
|
||||||
|
while chunk := fp.read(STATE_CHUNK_SIZE):
|
||||||
|
schunk.append_data(chunk)
|
||||||
|
|
||||||
|
fp = BytesIO()
|
||||||
|
|
||||||
|
for chunk_index in range(schunk.nchunks):
|
||||||
|
fp.write(schunk.decompress_chunk(chunk_index)) # pyright: ignore[reportArgumentType,reportUnusedCallResult]
|
||||||
|
|
||||||
|
data = pickle.load(fp)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Only one of `data` or `fp` must be provided")
|
raise ValueError("Only one of `data` or `fp` must be provided")
|
||||||
if substate_schema != state._to_schema():
|
substate_schema, state = data # pyright: ignore[reportUnknownVariableType,reportGeneralTypeIssues]
|
||||||
|
if substate_schema != state._to_schema(): # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
|
||||||
raise StateSchemaMismatchError()
|
raise StateSchemaMismatchError()
|
||||||
return state
|
return state # pyright: ignore[reportUnknownVariableType,reportReturnType]
|
||||||
|
|
||||||
|
|
||||||
T_STATE = TypeVar("T_STATE", bound=BaseState)
|
T_STATE = TypeVar("T_STATE", bound=BaseState)
|
||||||
|
Loading…
Reference in New Issue
Block a user