[ENG-4100]Throw warnings when Redis lock is held for more than the allowed threshold (#4522)
* Throw warnings when Redis lock is held for more than the allowed threshold * initial tests * fix tests and address comments * fix tests fr, and use pydantic validators * darglint fix * increase lock expiration in tests to 2500 * remove print statement --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
2d9849e00a
commit
c387f517b6
@ -684,6 +684,9 @@ class Config(Base):
|
|||||||
# Maximum expiration lock time for redis state manager
|
# Maximum expiration lock time for redis state manager
|
||||||
redis_lock_expiration: int = constants.Expiration.LOCK
|
redis_lock_expiration: int = constants.Expiration.LOCK
|
||||||
|
|
||||||
|
# Maximum lock time before warning for redis state manager.
|
||||||
|
redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
|
||||||
|
|
||||||
# Token expiration time for redis state manager
|
# Token expiration time for redis state manager
|
||||||
redis_token_expiration: int = constants.Expiration.TOKEN
|
redis_token_expiration: int = constants.Expiration.TOKEN
|
||||||
|
|
||||||
|
@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
|
|||||||
LOCK = 10000
|
LOCK = 10000
|
||||||
# The PING timeout
|
# The PING timeout
|
||||||
PING = 120
|
PING = 120
|
||||||
|
# The maximum time in milliseconds to hold a lock before throwing a warning.
|
||||||
|
LOCK_WARNING_THRESHOLD = 1000
|
||||||
|
|
||||||
|
|
||||||
class GitIgnore(SimpleNamespace):
|
class GitIgnore(SimpleNamespace):
|
||||||
|
@ -71,6 +71,11 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
BaseModelV1 = BaseModelV2
|
BaseModelV1 = BaseModelV2
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pydantic.v1 import validator
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
import wrapt
|
import wrapt
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from redis.exceptions import ResponseError
|
from redis.exceptions import ResponseError
|
||||||
@ -94,6 +99,7 @@ from reflex.utils.exceptions import (
|
|||||||
DynamicRouteArgShadowsStateVar,
|
DynamicRouteArgShadowsStateVar,
|
||||||
EventHandlerShadowsBuiltInStateMethod,
|
EventHandlerShadowsBuiltInStateMethod,
|
||||||
ImmutableStateError,
|
ImmutableStateError,
|
||||||
|
InvalidLockWarningThresholdError,
|
||||||
InvalidStateManagerMode,
|
InvalidStateManagerMode,
|
||||||
LockExpiredError,
|
LockExpiredError,
|
||||||
ReflexRuntimeError,
|
ReflexRuntimeError,
|
||||||
@ -2834,6 +2840,7 @@ class StateManager(Base, ABC):
|
|||||||
redis=redis,
|
redis=redis,
|
||||||
token_expiration=config.redis_token_expiration,
|
token_expiration=config.redis_token_expiration,
|
||||||
lock_expiration=config.redis_lock_expiration,
|
lock_expiration=config.redis_lock_expiration,
|
||||||
|
lock_warning_threshold=config.redis_lock_warning_threshold,
|
||||||
)
|
)
|
||||||
raise InvalidStateManagerMode(
|
raise InvalidStateManagerMode(
|
||||||
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
||||||
@ -3203,6 +3210,15 @@ def _default_lock_expiration() -> int:
|
|||||||
return get_config().redis_lock_expiration
|
return get_config().redis_lock_expiration
|
||||||
|
|
||||||
|
|
||||||
|
def _default_lock_warning_threshold() -> int:
|
||||||
|
"""Get the default lock warning threshold.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The default lock warning threshold.
|
||||||
|
"""
|
||||||
|
return get_config().redis_lock_warning_threshold
|
||||||
|
|
||||||
|
|
||||||
class StateManagerRedis(StateManager):
|
class StateManagerRedis(StateManager):
|
||||||
"""A state manager that stores states in redis."""
|
"""A state manager that stores states in redis."""
|
||||||
|
|
||||||
@ -3215,6 +3231,11 @@ class StateManagerRedis(StateManager):
|
|||||||
# The maximum time to hold a lock (ms).
|
# The maximum time to hold a lock (ms).
|
||||||
lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
|
lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
|
||||||
|
|
||||||
|
# The maximum time to hold a lock (ms) before warning.
|
||||||
|
lock_warning_threshold: int = pydantic.Field(
|
||||||
|
default_factory=_default_lock_warning_threshold
|
||||||
|
)
|
||||||
|
|
||||||
# The keyspace subscription string when redis is waiting for lock to be released
|
# The keyspace subscription string when redis is waiting for lock to be released
|
||||||
_redis_notify_keyspace_events: str = (
|
_redis_notify_keyspace_events: str = (
|
||||||
"K" # Enable keyspace notifications (target a particular key)
|
"K" # Enable keyspace notifications (target a particular key)
|
||||||
@ -3402,6 +3423,17 @@ class StateManagerRedis(StateManager):
|
|||||||
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
||||||
"or use `@rx.event(background=True)` decorator for long-running tasks."
|
"or use `@rx.event(background=True)` decorator for long-running tasks."
|
||||||
)
|
)
|
||||||
|
elif lock_id is not None:
|
||||||
|
time_taken = self.lock_expiration / 1000 - (
|
||||||
|
await self.redis.ttl(self._lock_key(token))
|
||||||
|
)
|
||||||
|
if time_taken > self.lock_warning_threshold / 1000:
|
||||||
|
console.warn(
|
||||||
|
f"Lock for token {token} was held too long {time_taken=}s, "
|
||||||
|
f"use `@rx.event(background=True)` decorator for long-running tasks.",
|
||||||
|
dedupe=True,
|
||||||
|
)
|
||||||
|
|
||||||
client_token, substate_name = _split_substate_key(token)
|
client_token, substate_name = _split_substate_key(token)
|
||||||
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
||||||
if state.parent_state is not None and state.get_full_name() != substate_name:
|
if state.parent_state is not None and state.get_full_name() != substate_name:
|
||||||
@ -3451,6 +3483,27 @@ class StateManagerRedis(StateManager):
|
|||||||
yield state
|
yield state
|
||||||
await self.set_state(token, state, lock_id)
|
await self.set_state(token, state, lock_id)
|
||||||
|
|
||||||
|
@validator("lock_warning_threshold")
|
||||||
|
@classmethod
|
||||||
|
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
|
||||||
|
"""Validate the lock warning threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lock_warning_threshold: The lock warning threshold.
|
||||||
|
values: The validated attributes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The lock warning threshold.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
|
||||||
|
"""
|
||||||
|
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
|
||||||
|
raise InvalidLockWarningThresholdError(
|
||||||
|
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
||||||
|
)
|
||||||
|
return lock_warning_threshold
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _lock_key(token: str) -> bytes:
|
def _lock_key(token: str) -> bytes:
|
||||||
"""Get the redis key for a token's lock.
|
"""Get the redis key for a token's lock.
|
||||||
|
@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
|
|||||||
# Info messages which have been printed.
|
# Info messages which have been printed.
|
||||||
_EMITTED_INFO = set()
|
_EMITTED_INFO = set()
|
||||||
|
|
||||||
|
# Warnings which have been printed.
|
||||||
|
_EMIITED_WARNINGS = set()
|
||||||
|
|
||||||
|
# Errors which have been printed.
|
||||||
|
_EMITTED_ERRORS = set()
|
||||||
|
|
||||||
|
# Success messages which have been printed.
|
||||||
|
_EMITTED_SUCCESS = set()
|
||||||
|
|
||||||
|
# Debug messages which have been printed.
|
||||||
|
_EMITTED_DEBUG = set()
|
||||||
|
|
||||||
|
# Logs which have been printed.
|
||||||
|
_EMITTED_LOGS = set()
|
||||||
|
|
||||||
|
# Prints which have been printed.
|
||||||
|
_EMITTED_PRINTS = set()
|
||||||
|
|
||||||
|
|
||||||
def set_log_level(log_level: LogLevel):
|
def set_log_level(log_level: LogLevel):
|
||||||
"""Set the log level.
|
"""Set the log level.
|
||||||
@ -55,25 +73,37 @@ def is_debug() -> bool:
|
|||||||
return _LOG_LEVEL <= LogLevel.DEBUG
|
return _LOG_LEVEL <= LogLevel.DEBUG
|
||||||
|
|
||||||
|
|
||||||
def print(msg: str, **kwargs):
|
def print(msg: str, dedupe: bool = False, **kwargs):
|
||||||
"""Print a message.
|
"""Print a message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The message to print.
|
msg: The message to print.
|
||||||
|
dedupe: If True, suppress multiple console logs of print message.
|
||||||
kwargs: Keyword arguments to pass to the print function.
|
kwargs: Keyword arguments to pass to the print function.
|
||||||
"""
|
"""
|
||||||
|
if dedupe:
|
||||||
|
if msg in _EMITTED_PRINTS:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
_EMITTED_PRINTS.add(msg)
|
||||||
_console.print(msg, **kwargs)
|
_console.print(msg, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def debug(msg: str, **kwargs):
|
def debug(msg: str, dedupe: bool = False, **kwargs):
|
||||||
"""Print a debug message.
|
"""Print a debug message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The debug message.
|
msg: The debug message.
|
||||||
|
dedupe: If True, suppress multiple console logs of debug message.
|
||||||
kwargs: Keyword arguments to pass to the print function.
|
kwargs: Keyword arguments to pass to the print function.
|
||||||
"""
|
"""
|
||||||
if is_debug():
|
if is_debug():
|
||||||
msg_ = f"[purple]Debug: {msg}[/purple]"
|
msg_ = f"[purple]Debug: {msg}[/purple]"
|
||||||
|
if dedupe:
|
||||||
|
if msg_ in _EMITTED_DEBUG:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
_EMITTED_DEBUG.add(msg_)
|
||||||
if progress := kwargs.pop("progress", None):
|
if progress := kwargs.pop("progress", None):
|
||||||
progress.console.print(msg_, **kwargs)
|
progress.console.print(msg_, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
|
|||||||
print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
|
print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def success(msg: str, **kwargs):
|
def success(msg: str, dedupe: bool = False, **kwargs):
|
||||||
"""Print a success message.
|
"""Print a success message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The success message.
|
msg: The success message.
|
||||||
|
dedupe: If True, suppress multiple console logs of success message.
|
||||||
kwargs: Keyword arguments to pass to the print function.
|
kwargs: Keyword arguments to pass to the print function.
|
||||||
"""
|
"""
|
||||||
if _LOG_LEVEL <= LogLevel.INFO:
|
if _LOG_LEVEL <= LogLevel.INFO:
|
||||||
|
if dedupe:
|
||||||
|
if msg in _EMITTED_SUCCESS:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
_EMITTED_SUCCESS.add(msg)
|
||||||
print(f"[green]Success: {msg}[/green]", **kwargs)
|
print(f"[green]Success: {msg}[/green]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def log(msg: str, **kwargs):
|
def log(msg: str, dedupe: bool = False, **kwargs):
|
||||||
"""Takes a string and logs it to the console.
|
"""Takes a string and logs it to the console.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The message to log.
|
msg: The message to log.
|
||||||
|
dedupe: If True, suppress multiple console logs of log message.
|
||||||
kwargs: Keyword arguments to pass to the print function.
|
kwargs: Keyword arguments to pass to the print function.
|
||||||
"""
|
"""
|
||||||
if _LOG_LEVEL <= LogLevel.INFO:
|
if _LOG_LEVEL <= LogLevel.INFO:
|
||||||
|
if dedupe:
|
||||||
|
if msg in _EMITTED_LOGS:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
_EMITTED_LOGS.add(msg)
|
||||||
_console.log(msg, **kwargs)
|
_console.log(msg, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
|
|||||||
_console.rule(title, **kwargs)
|
_console.rule(title, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def warn(msg: str, **kwargs):
|
def warn(msg: str, dedupe: bool = False, **kwargs):
|
||||||
"""Print a warning message.
|
"""Print a warning message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The warning message.
|
msg: The warning message.
|
||||||
|
dedupe: If True, suppress multiple console logs of warning message.
|
||||||
kwargs: Keyword arguments to pass to the print function.
|
kwargs: Keyword arguments to pass to the print function.
|
||||||
"""
|
"""
|
||||||
if _LOG_LEVEL <= LogLevel.WARNING:
|
if _LOG_LEVEL <= LogLevel.WARNING:
|
||||||
|
if dedupe:
|
||||||
|
if msg in _EMIITED_WARNINGS:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
_EMIITED_WARNINGS.add(msg)
|
||||||
print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
|
print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -169,14 +217,20 @@ def deprecate(
|
|||||||
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
|
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
|
||||||
|
|
||||||
|
|
||||||
def error(msg: str, **kwargs):
|
def error(msg: str, dedupe: bool = False, **kwargs):
|
||||||
"""Print an error message.
|
"""Print an error message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The error message.
|
msg: The error message.
|
||||||
|
dedupe: If True, suppress multiple console logs of error message.
|
||||||
kwargs: Keyword arguments to pass to the print function.
|
kwargs: Keyword arguments to pass to the print function.
|
||||||
"""
|
"""
|
||||||
if _LOG_LEVEL <= LogLevel.ERROR:
|
if _LOG_LEVEL <= LogLevel.ERROR:
|
||||||
|
if dedupe:
|
||||||
|
if msg in _EMITTED_ERRORS:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
_EMITTED_ERRORS.add(msg)
|
||||||
print(f"[red]{msg}[/red]", **kwargs)
|
print(f"[red]{msg}[/red]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
|
|||||||
" Please install it through your system package manager."
|
" Please install it through your system package manager."
|
||||||
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
|
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidLockWarningThresholdError(ReflexError):
|
||||||
|
"""Raised when an invalid lock warning threshold is provided."""
|
||||||
|
@ -56,6 +56,7 @@ from reflex.state import (
|
|||||||
from reflex.testing import chdir
|
from reflex.testing import chdir
|
||||||
from reflex.utils import format, prerequisites, types
|
from reflex.utils import format, prerequisites, types
|
||||||
from reflex.utils.exceptions import (
|
from reflex.utils.exceptions import (
|
||||||
|
InvalidLockWarningThresholdError,
|
||||||
ReflexRuntimeError,
|
ReflexRuntimeError,
|
||||||
SetUndefinedStateVarError,
|
SetUndefinedStateVarError,
|
||||||
StateSerializationError,
|
StateSerializationError,
|
||||||
@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
|
|||||||
from .states import GenState
|
from .states import GenState
|
||||||
|
|
||||||
CI = bool(os.environ.get("CI", False))
|
CI = bool(os.environ.get("CI", False))
|
||||||
LOCK_EXPIRATION = 2000 if CI else 300
|
LOCK_EXPIRATION = 2500 if CI else 300
|
||||||
|
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
|
||||||
|
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
|
||||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
||||||
|
|
||||||
|
|
||||||
@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
|
|||||||
substate_token_redis: A token + substate name for looking up in state manager.
|
substate_token_redis: A token + substate name for looking up in state manager.
|
||||||
"""
|
"""
|
||||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||||
|
|
||||||
async with state_manager_redis.modify_state(substate_token_redis):
|
async with state_manager_redis.modify_state(substate_token_redis):
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
|
|||||||
unexp_num1 = 666
|
unexp_num1 = 666
|
||||||
|
|
||||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||||
|
|
||||||
order = []
|
order = []
|
||||||
|
|
||||||
@ -1840,6 +1845,39 @@ async def test_state_manager_lock_expire_contend(
|
|||||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_manager_lock_warning_threshold_contend(
|
||||||
|
state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
|
||||||
|
):
|
||||||
|
"""Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager_redis: A state manager instance.
|
||||||
|
token: A token.
|
||||||
|
substate_token_redis: A token + substate name for looking up in state manager.
|
||||||
|
mocker: Pytest mocker object.
|
||||||
|
"""
|
||||||
|
console_warn = mocker.patch("reflex.utils.console.warn")
|
||||||
|
|
||||||
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||||
|
|
||||||
|
order = []
|
||||||
|
|
||||||
|
async def _coro_blocker():
|
||||||
|
async with state_manager_redis.modify_state(substate_token_redis):
|
||||||
|
order.append("blocker")
|
||||||
|
await asyncio.sleep(LOCK_WARN_SLEEP)
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(_coro_blocker()),
|
||||||
|
]
|
||||||
|
|
||||||
|
await tasks[0]
|
||||||
|
console_warn.assert_called()
|
||||||
|
assert console_warn.call_count == 7
|
||||||
|
|
||||||
|
|
||||||
class CopyingAsyncMock(AsyncMock):
|
class CopyingAsyncMock(AsyncMock):
|
||||||
"""An AsyncMock, but deepcopy the args and kwargs first."""
|
"""An AsyncMock, but deepcopy the args and kwargs first."""
|
||||||
|
|
||||||
@ -3253,12 +3291,42 @@ async def test_setvar_async_setter():
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"expiration_kwargs, expected_values",
|
"expiration_kwargs, expected_values",
|
||||||
[
|
[
|
||||||
({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
|
(
|
||||||
|
{"redis_lock_expiration": 20000},
|
||||||
|
(
|
||||||
|
20000,
|
||||||
|
constants.Expiration.TOKEN,
|
||||||
|
constants.Expiration.LOCK_WARNING_THRESHOLD,
|
||||||
|
),
|
||||||
|
),
|
||||||
(
|
(
|
||||||
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
|
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
|
||||||
(50000, 5600),
|
(50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"redis_token_expiration": 7600},
|
||||||
|
(
|
||||||
|
constants.Expiration.LOCK,
|
||||||
|
7600,
|
||||||
|
constants.Expiration.LOCK_WARNING_THRESHOLD,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
|
||||||
|
(50000, constants.Expiration.TOKEN, 1500),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
|
||||||
|
(constants.Expiration.LOCK, 5600, 3000),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"redis_lock_expiration": 50000,
|
||||||
|
"redis_token_expiration": 5600,
|
||||||
|
"redis_lock_warning_threshold": 2000,
|
||||||
|
},
|
||||||
|
(50000, 5600, 2000),
|
||||||
),
|
),
|
||||||
({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
|
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
|
||||||
@ -3288,6 +3356,44 @@ config = rx.Config(
|
|||||||
state_manager = StateManager.create(state=State)
|
state_manager = StateManager.create(state=State)
|
||||||
assert state_manager.lock_expiration == expected_values[0] # type: ignore
|
assert state_manager.lock_expiration == expected_values[0] # type: ignore
|
||||||
assert state_manager.token_expiration == expected_values[1] # type: ignore
|
assert state_manager.token_expiration == expected_values[1] # type: ignore
|
||||||
|
assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"redis_lock_expiration, redis_lock_warning_threshold",
|
||||||
|
[
|
||||||
|
(10000, 10000),
|
||||||
|
(20000, 30000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
|
||||||
|
tmp_path, redis_lock_expiration, redis_lock_warning_threshold
|
||||||
|
):
|
||||||
|
proj_root = tmp_path / "project1"
|
||||||
|
proj_root.mkdir()
|
||||||
|
|
||||||
|
config_string = f"""
|
||||||
|
import reflex as rx
|
||||||
|
config = rx.Config(
|
||||||
|
app_name="project1",
|
||||||
|
redis_url="redis://localhost:6379",
|
||||||
|
state_manager_mode="redis",
|
||||||
|
redis_lock_expiration = {redis_lock_expiration},
|
||||||
|
redis_lock_warning_threshold = {redis_lock_warning_threshold},
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
(proj_root / "rxconfig.py").write_text(dedent(config_string))
|
||||||
|
|
||||||
|
with chdir(proj_root):
|
||||||
|
# reload config for each parameter to avoid stale values
|
||||||
|
reflex.config.get_config(reload=True)
|
||||||
|
from reflex.state import State, StateManager
|
||||||
|
|
||||||
|
with pytest.raises(InvalidLockWarningThresholdError):
|
||||||
|
StateManager.create(state=State)
|
||||||
|
del sys.modules[constants.Config.MODULE]
|
||||||
|
|
||||||
|
|
||||||
class MixinState(State, mixin=True):
|
class MixinState(State, mixin=True):
|
||||||
|
Loading…
Reference in New Issue
Block a user