[ENG-4100]Throw warnings when Redis lock is held for more than the allowed threshold (#4522)
* Throw warnings when Redis lock is held for more than the allowed threshold * initial tests * fix tests and address comments * fix tests fr, and use pydantic validators * darglint fix * increase lock expiration in tests to 2500 * remove print statement --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
2d9849e00a
commit
c387f517b6
@ -684,6 +684,9 @@ class Config(Base):
|
||||
# Maximum lock expiration time (in milliseconds) for the redis state manager.
|
||||
redis_lock_expiration: int = constants.Expiration.LOCK
|
||||
|
||||
# Maximum time (in milliseconds) a lock may be held before the redis state manager emits a warning.
|
||||
redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
|
||||
|
||||
# Token expiration time for redis state manager
|
||||
redis_token_expiration: int = constants.Expiration.TOKEN
|
||||
|
||||
|
@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
|
||||
LOCK = 10000
|
||||
# The PING timeout
|
||||
PING = 120
|
||||
# The maximum time in milliseconds to hold a lock before throwing a warning.
|
||||
LOCK_WARNING_THRESHOLD = 1000
|
||||
|
||||
|
||||
class GitIgnore(SimpleNamespace):
|
||||
|
@ -71,6 +71,11 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
BaseModelV1 = BaseModelV2
|
||||
|
||||
try:
|
||||
from pydantic.v1 import validator
|
||||
except ModuleNotFoundError:
|
||||
from pydantic import validator
|
||||
|
||||
import wrapt
|
||||
from redis.asyncio import Redis
|
||||
from redis.exceptions import ResponseError
|
||||
@ -94,6 +99,7 @@ from reflex.utils.exceptions import (
|
||||
DynamicRouteArgShadowsStateVar,
|
||||
EventHandlerShadowsBuiltInStateMethod,
|
||||
ImmutableStateError,
|
||||
InvalidLockWarningThresholdError,
|
||||
InvalidStateManagerMode,
|
||||
LockExpiredError,
|
||||
ReflexRuntimeError,
|
||||
@ -2834,6 +2840,7 @@ class StateManager(Base, ABC):
|
||||
redis=redis,
|
||||
token_expiration=config.redis_token_expiration,
|
||||
lock_expiration=config.redis_lock_expiration,
|
||||
lock_warning_threshold=config.redis_lock_warning_threshold,
|
||||
)
|
||||
raise InvalidStateManagerMode(
|
||||
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
||||
@ -3203,6 +3210,15 @@ def _default_lock_expiration() -> int:
|
||||
return get_config().redis_lock_expiration
|
||||
|
||||
|
||||
def _default_lock_warning_threshold() -> int:
    """Look up the lock warning threshold from the current config.

    Returns:
        The `redis_lock_warning_threshold` value of the config returned by `get_config()`.
    """
    config = get_config()
    return config.redis_lock_warning_threshold
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
@ -3215,6 +3231,11 @@ class StateManagerRedis(StateManager):
|
||||
# The maximum time to hold a lock (ms).
|
||||
lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
|
||||
|
||||
# The maximum time to hold a lock (ms) before warning.
|
||||
lock_warning_threshold: int = pydantic.Field(
|
||||
default_factory=_default_lock_warning_threshold
|
||||
)
|
||||
|
||||
# The keyspace subscription string when redis is waiting for lock to be released
|
||||
_redis_notify_keyspace_events: str = (
|
||||
"K" # Enable keyspace notifications (target a particular key)
|
||||
@ -3402,6 +3423,17 @@ class StateManagerRedis(StateManager):
|
||||
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
||||
"or use `@rx.event(background=True)` decorator for long-running tasks."
|
||||
)
|
||||
elif lock_id is not None:
|
||||
time_taken = self.lock_expiration / 1000 - (
|
||||
await self.redis.ttl(self._lock_key(token))
|
||||
)
|
||||
if time_taken > self.lock_warning_threshold / 1000:
|
||||
console.warn(
|
||||
f"Lock for token {token} was held too long {time_taken=}s, "
|
||||
f"use `@rx.event(background=True)` decorator for long-running tasks.",
|
||||
dedupe=True,
|
||||
)
|
||||
|
||||
client_token, substate_name = _split_substate_key(token)
|
||||
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
||||
if state.parent_state is not None and state.get_full_name() != substate_name:
|
||||
@ -3451,6 +3483,27 @@ class StateManagerRedis(StateManager):
|
||||
yield state
|
||||
await self.set_state(token, state, lock_id)
|
||||
|
||||
@validator("lock_warning_threshold")
@classmethod
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
    """Validate that the lock warning threshold is below the lock expiration.

    Args:
        lock_warning_threshold: The lock warning threshold.
        values: The previously validated attributes.

    Returns:
        The lock warning threshold.

    Raises:
        InvalidLockWarningThresholdError: If the lock warning threshold is not
            strictly less than the lock expiration.
    """
    # A field that failed its own validation is absent from `values`; in that
    # case skip the comparison instead of raising a KeyError, so the original
    # validation error for `lock_expiration` is the one that gets reported.
    lock_expiration = values.get("lock_expiration")
    if lock_expiration is not None and lock_warning_threshold >= lock_expiration:
        raise InvalidLockWarningThresholdError(
            f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
        )
    return lock_warning_threshold
|
||||
|
||||
@staticmethod
|
||||
def _lock_key(token: str) -> bytes:
|
||||
"""Get the redis key for a token's lock.
|
||||
|
@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
|
||||
# Info messages which have been printed.
|
||||
_EMITTED_INFO = set()
|
||||
|
||||
# Warnings which have been printed.
|
||||
_EMIITED_WARNINGS = set()
|
||||
|
||||
# Errors which have been printed.
|
||||
_EMITTED_ERRORS = set()
|
||||
|
||||
# Success messages which have been printed.
|
||||
_EMITTED_SUCCESS = set()
|
||||
|
||||
# Debug messages which have been printed.
|
||||
_EMITTED_DEBUG = set()
|
||||
|
||||
# Logs which have been printed.
|
||||
_EMITTED_LOGS = set()
|
||||
|
||||
# Plain print messages which have been printed.
|
||||
_EMITTED_PRINTS = set()
|
||||
|
||||
|
||||
def set_log_level(log_level: LogLevel):
|
||||
"""Set the log level.
|
||||
@ -55,25 +73,37 @@ def is_debug() -> bool:
|
||||
return _LOG_LEVEL <= LogLevel.DEBUG
|
||||
|
||||
|
||||
def print(msg: str, **kwargs):
|
||||
def print(msg: str, dedupe: bool = False, **kwargs):
    """Write a message to the console.

    Args:
        msg: The message to print.
        dedupe: If True, skip the message when the same text was already printed with dedupe enabled.
        kwargs: Keyword arguments to pass to the print function.
    """
    if dedupe and msg in _EMITTED_PRINTS:
        return
    if dedupe:
        # Remember the message so later deduped calls are suppressed.
        _EMITTED_PRINTS.add(msg)
    _console.print(msg, **kwargs)
|
||||
|
||||
|
||||
def debug(msg: str, **kwargs):
|
||||
def debug(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Print a debug message.
|
||||
|
||||
Args:
|
||||
msg: The debug message.
|
||||
dedupe: If True, suppress multiple console logs of debug message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if is_debug():
|
||||
msg_ = f"[purple]Debug: {msg}[/purple]"
|
||||
if dedupe:
|
||||
if msg_ in _EMITTED_DEBUG:
|
||||
return
|
||||
else:
|
||||
_EMITTED_DEBUG.add(msg_)
|
||||
if progress := kwargs.pop("progress", None):
|
||||
progress.console.print(msg_, **kwargs)
|
||||
else:
|
||||
@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
|
||||
print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
|
||||
|
||||
|
||||
def success(msg: str, **kwargs):
|
||||
def success(msg: str, dedupe: bool = False, **kwargs):
    """Emit a success message when the log level allows it.

    Args:
        msg: The success message.
        dedupe: If True, skip the message when the same text was already emitted with dedupe enabled.
        kwargs: Keyword arguments to pass to the print function.
    """
    # NOTE(review): the comparison is kept as `<=` (negated) rather than
    # rewritten with `>`, in case LogLevel only customizes `<=` — confirm.
    if not (_LOG_LEVEL <= LogLevel.INFO):
        return
    if dedupe and msg in _EMITTED_SUCCESS:
        return
    if dedupe:
        _EMITTED_SUCCESS.add(msg)
    print(f"[green]Success: {msg}[/green]", **kwargs)
|
||||
|
||||
|
||||
def log(msg: str, **kwargs):
|
||||
def log(msg: str, dedupe: bool = False, **kwargs):
    """Log a string to the console when the log level allows it.

    Args:
        msg: The message to log.
        dedupe: If True, skip the message when the same text was already logged with dedupe enabled.
        kwargs: Keyword arguments to pass to the print function.
    """
    # NOTE(review): the comparison is kept as `<=` (negated) rather than
    # rewritten with `>`, in case LogLevel only customizes `<=` — confirm.
    if not (_LOG_LEVEL <= LogLevel.INFO):
        return
    if dedupe and msg in _EMITTED_LOGS:
        return
    if dedupe:
        _EMITTED_LOGS.add(msg)
    _console.log(msg, **kwargs)
|
||||
|
||||
|
||||
@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
|
||||
_console.rule(title, **kwargs)
|
||||
|
||||
|
||||
def warn(msg: str, **kwargs):
|
||||
def warn(msg: str, dedupe: bool = False, **kwargs):
    """Emit a warning message when the log level allows it.

    Args:
        msg: The warning message.
        dedupe: If True, skip the message when the same text was already emitted with dedupe enabled.
        kwargs: Keyword arguments to pass to the print function.
    """
    # NOTE(review): the comparison is kept as `<=` (negated) rather than
    # rewritten with `>`, in case LogLevel only customizes `<=` — confirm.
    if not (_LOG_LEVEL <= LogLevel.WARNING):
        return
    if dedupe and msg in _EMIITED_WARNINGS:
        return
    if dedupe:
        _EMIITED_WARNINGS.add(msg)
    print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
|
||||
|
||||
|
||||
@ -169,14 +217,20 @@ def deprecate(
|
||||
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
|
||||
|
||||
|
||||
def error(msg: str, **kwargs):
|
||||
def error(msg: str, dedupe: bool = False, **kwargs):
    """Emit an error message when the log level allows it.

    Args:
        msg: The error message.
        dedupe: If True, skip the message when the same text was already emitted with dedupe enabled.
        kwargs: Keyword arguments to pass to the print function.
    """
    # NOTE(review): the comparison is kept as `<=` (negated) rather than
    # rewritten with `>`, in case LogLevel only customizes `<=` — confirm.
    if not (_LOG_LEVEL <= LogLevel.ERROR):
        return
    if dedupe and msg in _EMITTED_ERRORS:
        return
    if dedupe:
        _EMITTED_ERRORS.add(msg)
    print(f"[red]{msg}[/red]", **kwargs)
|
||||
|
||||
|
||||
|
@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
|
||||
" Please install it through your system package manager."
|
||||
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
|
||||
)
|
||||
|
||||
|
||||
class InvalidLockWarningThresholdError(ReflexError):
    """Error signalling that the provided lock warning threshold is not valid."""
|
||||
|
@ -56,6 +56,7 @@ from reflex.state import (
|
||||
from reflex.testing import chdir
|
||||
from reflex.utils import format, prerequisites, types
|
||||
from reflex.utils.exceptions import (
|
||||
InvalidLockWarningThresholdError,
|
||||
ReflexRuntimeError,
|
||||
SetUndefinedStateVarError,
|
||||
StateSerializationError,
|
||||
@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
|
||||
from .states import GenState
|
||||
|
||||
CI = bool(os.environ.get("CI", False))
|
||||
LOCK_EXPIRATION = 2000 if CI else 300
|
||||
LOCK_EXPIRATION = 2500 if CI else 300
|
||||
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
|
||||
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
|
||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
||||
|
||||
|
||||
@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
async with state_manager_redis.modify_state(substate_token_redis):
|
||||
await asyncio.sleep(0.01)
|
||||
@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
|
||||
unexp_num1 = 666
|
||||
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
order = []
|
||||
|
||||
@ -1840,6 +1845,39 @@ async def test_state_manager_lock_expire_contend(
|
||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_state_manager_lock_warning_threshold_contend(
    state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
):
    """Check that holding the state lock past the warning threshold emits warnings.

    Args:
        state_manager_redis: A state manager instance.
        token: A token.
        substate_token_redis: A token + substate name for looking up in state manager.
        mocker: Pytest mocker object.
    """
    warn_mock = mocker.patch("reflex.utils.console.warn")

    state_manager_redis.lock_expiration = LOCK_EXPIRATION
    state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD

    events = []

    async def _hold_lock():
        # Keep the lock for LOCK_WARN_SLEEP seconds before releasing it.
        async with state_manager_redis.modify_state(substate_token_redis):
            events.append("blocker")
            await asyncio.sleep(LOCK_WARN_SLEEP)

    blocker_task = asyncio.create_task(_hold_lock())
    await blocker_task

    warn_mock.assert_called()
    # NOTE(review): presumably one warning per state written back — confirm
    # where the expected count of 7 comes from.
    assert warn_mock.call_count == 7
|
||||
|
||||
|
||||
class CopyingAsyncMock(AsyncMock):
|
||||
"""An AsyncMock, but deepcopy the args and kwargs first."""
|
||||
|
||||
@ -3253,12 +3291,42 @@ async def test_setvar_async_setter():
|
||||
@pytest.mark.parametrize(
|
||||
"expiration_kwargs, expected_values",
|
||||
[
|
||||
({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
|
||||
(
|
||||
{"redis_lock_expiration": 20000},
|
||||
(
|
||||
20000,
|
||||
constants.Expiration.TOKEN,
|
||||
constants.Expiration.LOCK_WARNING_THRESHOLD,
|
||||
),
|
||||
),
|
||||
(
|
||||
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
|
||||
(50000, 5600),
|
||||
(50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
|
||||
),
|
||||
(
|
||||
{"redis_token_expiration": 7600},
|
||||
(
|
||||
constants.Expiration.LOCK,
|
||||
7600,
|
||||
constants.Expiration.LOCK_WARNING_THRESHOLD,
|
||||
),
|
||||
),
|
||||
(
|
||||
{"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
|
||||
(50000, constants.Expiration.TOKEN, 1500),
|
||||
),
|
||||
(
|
||||
{"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
|
||||
(constants.Expiration.LOCK, 5600, 3000),
|
||||
),
|
||||
(
|
||||
{
|
||||
"redis_lock_expiration": 50000,
|
||||
"redis_token_expiration": 5600,
|
||||
"redis_lock_warning_threshold": 2000,
|
||||
},
|
||||
(50000, 5600, 2000),
|
||||
),
|
||||
({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
|
||||
],
|
||||
)
|
||||
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
|
||||
@ -3288,6 +3356,44 @@ config = rx.Config(
|
||||
state_manager = StateManager.create(state=State)
|
||||
assert state_manager.lock_expiration == expected_values[0] # type: ignore
|
||||
assert state_manager.token_expiration == expected_values[1] # type: ignore
|
||||
assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize(
    "redis_lock_expiration, redis_lock_warning_threshold",
    [
        (10000, 10000),
        (20000, 30000),
    ],
)
def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
    tmp_path, redis_lock_expiration, redis_lock_warning_threshold
):
    """Test that a warning threshold equal to or above the lock expiration is rejected.

    Args:
        tmp_path: Pytest temporary directory used as the project root parent.
        redis_lock_expiration: The lock expiration written to the config.
        redis_lock_warning_threshold: The lock warning threshold written to the config.
    """
    proj_root = tmp_path / "project1"
    proj_root.mkdir()

    config_string = f"""
    import reflex as rx
    config = rx.Config(
        app_name="project1",
        redis_url="redis://localhost:6379",
        state_manager_mode="redis",
        redis_lock_expiration = {redis_lock_expiration},
        redis_lock_warning_threshold = {redis_lock_warning_threshold},
    )
    """

    (proj_root / "rxconfig.py").write_text(dedent(config_string))

    with chdir(proj_root):
        try:
            # reload config for each parameter to avoid stale values
            reflex.config.get_config(reload=True)
            from reflex.state import State, StateManager

            with pytest.raises(InvalidLockWarningThresholdError):
                StateManager.create(state=State)
        finally:
            # Always drop the module entry, even when the assertion above
            # fails, so it cannot leak into later tests.
            sys.modules.pop(constants.Config.MODULE, None)
|
||||
|
||||
|
||||
class MixinState(State, mixin=True):
|
||||
|
Loading…
Reference in New Issue
Block a user