[ENG-4100] Throw warnings when Redis lock is held for more than the allowed threshold (#4522)

* Throw warnings when Redis lock is held for more than the allowed threshold

* initial tests

* fix tests and address comments

* fix tests fr, and use pydantic validators

* darglint fix

* increase lock expiration in tests to 2500

* remove print statement

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
Elijah Ahianyo 2024-12-12 19:36:31 +00:00 committed by GitHub
parent 2d9849e00a
commit c387f517b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 232 additions and 10 deletions

View File

@ -684,6 +684,9 @@ class Config(Base):
# Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK
# Maximum lock time before warning for redis state manager.
redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
# Token expiration time for redis state manager
redis_token_expiration: int = constants.Expiration.TOKEN

View File

@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
LOCK = 10000
# The PING timeout
PING = 120
# The maximum time in milliseconds to hold a lock before throwing a warning.
LOCK_WARNING_THRESHOLD = 1000
class GitIgnore(SimpleNamespace):

View File

@ -71,6 +71,11 @@ try:
except ModuleNotFoundError:
BaseModelV1 = BaseModelV2
try:
from pydantic.v1 import validator
except ModuleNotFoundError:
from pydantic import validator
import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
@ -94,6 +99,7 @@ from reflex.utils.exceptions import (
DynamicRouteArgShadowsStateVar,
EventHandlerShadowsBuiltInStateMethod,
ImmutableStateError,
InvalidLockWarningThresholdError,
InvalidStateManagerMode,
LockExpiredError,
ReflexRuntimeError,
@ -2834,6 +2840,7 @@ class StateManager(Base, ABC):
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
lock_warning_threshold=config.redis_lock_warning_threshold,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@ -3203,6 +3210,15 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration
def _default_lock_warning_threshold() -> int:
    """Look up the lock warning threshold from the active config.

    Returns:
        The `redis_lock_warning_threshold` value of the current config.
    """
    config = get_config()
    return config.redis_lock_warning_threshold
class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""
@ -3215,6 +3231,11 @@ class StateManagerRedis(StateManager):
# The maximum time to hold a lock (ms).
lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
# The maximum time to hold a lock (ms) before warning.
lock_warning_threshold: int = pydantic.Field(
default_factory=_default_lock_warning_threshold
)
# The keyspace subscription string when redis is waiting for lock to be released
_redis_notify_keyspace_events: str = (
"K" # Enable keyspace notifications (target a particular key)
@ -3402,6 +3423,17 @@ class StateManagerRedis(StateManager):
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.event(background=True)` decorator for long-running tasks."
)
elif lock_id is not None:
time_taken = self.lock_expiration / 1000 - (
await self.redis.ttl(self._lock_key(token))
)
if time_taken > self.lock_warning_threshold / 1000:
console.warn(
f"Lock for token {token} was held too long {time_taken=}s, "
f"use `@rx.event(background=True)` decorator for long-running tasks.",
dedupe=True,
)
client_token, substate_name = _split_substate_key(token)
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
if state.parent_state is not None and state.get_full_name() != substate_name:
@ -3451,6 +3483,27 @@ class StateManagerRedis(StateManager):
yield state
await self.set_state(token, state, lock_id)
@validator("lock_warning_threshold")
@classmethod
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
    """Validate the lock warning threshold.

    The threshold must be strictly less than the lock expiration, otherwise
    the lock would expire before a warning could be emitted.

    Args:
        lock_warning_threshold: The lock warning threshold.
        values: The validated attributes.

    Returns:
        The lock warning threshold.

    Raises:
        InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
    """
    lock_expiration = values.get("lock_expiration")
    if lock_expiration is None:
        # `lock_expiration` is missing from `values` when it failed its own
        # validation. Skip the comparison so that error is the one reported
        # instead of a KeyError raised from this validator.
        return lock_warning_threshold
    if lock_warning_threshold >= lock_expiration:
        raise InvalidLockWarningThresholdError(
            f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
        )
    return lock_warning_threshold
@staticmethod
def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock.

View File

@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
# Info messages which have been printed.
_EMITTED_INFO = set()
# Warning messages already printed with dedupe=True (checked by `warn`).
# NOTE(review): the name is misspelled ("EMIITED" instead of "EMITTED"); `warn`
# references this exact spelling, so both must be renamed together.
_EMIITED_WARNINGS = set()
# Error messages already printed with dedupe=True (checked by `error`).
_EMITTED_ERRORS = set()
# Success messages already printed with dedupe=True (checked by `success`).
_EMITTED_SUCCESS = set()
# Debug messages already printed with dedupe=True (checked by `debug`, which
# stores the formatted "[purple]Debug: ..." text rather than the raw message).
_EMITTED_DEBUG = set()
# Log messages already printed with dedupe=True (checked by `log`).
_EMITTED_LOGS = set()
# Plain print messages already printed with dedupe=True (checked by `print`).
_EMITTED_PRINTS = set()
def set_log_level(log_level: LogLevel):
"""Set the log level.
@ -55,25 +73,37 @@ def is_debug() -> bool:
return _LOG_LEVEL <= LogLevel.DEBUG
def print(msg: str, **kwargs):
def print(msg: str, dedupe: bool = False, **kwargs):
    """Write a message to the console.

    Args:
        msg: The message to print.
        dedupe: If True, a message that was already printed with dedupe is skipped.
        kwargs: Keyword arguments to pass to the print function.
    """
    already_printed = dedupe and msg in _EMITTED_PRINTS
    if already_printed:
        return
    if dedupe:
        # Remember the message so later dedupe calls can skip it.
        _EMITTED_PRINTS.add(msg)
    _console.print(msg, **kwargs)
def debug(msg: str, **kwargs):
def debug(msg: str, dedupe: bool = False, **kwargs):
"""Print a debug message.
Args:
msg: The debug message.
dedupe: If True, suppress multiple console logs of debug message.
kwargs: Keyword arguments to pass to the print function.
"""
if is_debug():
msg_ = f"[purple]Debug: {msg}[/purple]"
if dedupe:
if msg_ in _EMITTED_DEBUG:
return
else:
_EMITTED_DEBUG.add(msg_)
if progress := kwargs.pop("progress", None):
progress.console.print(msg_, **kwargs)
else:
@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
def success(msg: str, **kwargs):
def success(msg: str, dedupe: bool = False, **kwargs):
    """Write a green success message to the console.

    Nothing is printed unless the log level is at or below INFO.

    Args:
        msg: The success message.
        dedupe: If True, a message that was already printed with dedupe is skipped.
        kwargs: Keyword arguments to pass to the print function.
    """
    # Negate the original `<=` test rather than switching comparison operators.
    if not _LOG_LEVEL <= LogLevel.INFO:
        return
    if dedupe:
        if msg in _EMITTED_SUCCESS:
            return
        _EMITTED_SUCCESS.add(msg)
    print(f"[green]Success: {msg}[/green]", **kwargs)
def log(msg: str, **kwargs):
def log(msg: str, dedupe: bool = False, **kwargs):
    """Send a message to the console log.

    Nothing is logged unless the log level is at or below INFO.

    Args:
        msg: The message to log.
        dedupe: If True, a message that was already logged with dedupe is skipped.
        kwargs: Keyword arguments to pass to the print function.
    """
    # Negate the original `<=` test rather than switching comparison operators.
    if not _LOG_LEVEL <= LogLevel.INFO:
        return
    seen_before = dedupe and msg in _EMITTED_LOGS
    if seen_before:
        return
    if dedupe:
        _EMITTED_LOGS.add(msg)
    _console.log(msg, **kwargs)
@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
_console.rule(title, **kwargs)
def warn(msg: str, **kwargs):
def warn(msg: str, dedupe: bool = False, **kwargs):
    """Write an orange warning message to the console.

    Nothing is printed unless the log level is at or below WARNING.

    Args:
        msg: The warning message.
        dedupe: If True, a message that was already printed with dedupe is skipped.
        kwargs: Keyword arguments to pass to the print function.
    """
    # Negate the original `<=` test rather than switching comparison operators.
    if not _LOG_LEVEL <= LogLevel.WARNING:
        return
    if dedupe:
        if msg in _EMIITED_WARNINGS:
            return
        _EMIITED_WARNINGS.add(msg)
    print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
@ -169,14 +217,20 @@ def deprecate(
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
def error(msg: str, **kwargs):
def error(msg: str, dedupe: bool = False, **kwargs):
    """Write a red error message to the console.

    Nothing is printed unless the log level is at or below ERROR.

    Args:
        msg: The error message.
        dedupe: If True, a message that was already printed with dedupe is skipped.
        kwargs: Keyword arguments to pass to the print function.
    """
    # Negate the original `<=` test rather than switching comparison operators.
    if not _LOG_LEVEL <= LogLevel.ERROR:
        return
    is_repeat = dedupe and msg in _EMITTED_ERRORS
    if is_repeat:
        return
    if dedupe:
        _EMITTED_ERRORS.add(msg)
    print(f"[red]{msg}[/red]", **kwargs)

View File

@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
" Please install it through your system package manager."
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
)
class InvalidLockWarningThresholdError(ReflexError):
    """Signals that the configured lock warning threshold is not valid."""

View File

@ -56,6 +56,7 @@ from reflex.state import (
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import (
InvalidLockWarningThresholdError,
ReflexRuntimeError,
SetUndefinedStateVarError,
StateSerializationError,
@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
from .states import GenState
# True when the CI environment variable is set to any non-empty value; the
# timings below are longer in that case.
CI = bool(os.environ.get("CI", False))
# Assigned to `state_manager_redis.lock_expiration` in the lock tests.
# Both assignments appear here; the second one (2500 under CI) takes effect.
LOCK_EXPIRATION = 2000 if CI else 300
LOCK_EXPIRATION = 2500 if CI else 300
# Assigned to `state_manager_redis.lock_warning_threshold` in the lock tests.
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
# Seconds slept (via asyncio.sleep) while holding the lock in the warning test.
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
# NOTE(review): presumably seconds slept to let the lock expire — usage is not
# visible here, confirm against the lock-expire tests.
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
substate_token_redis: A token + substate name for looking up in state manager.
"""
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(0.01)
@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
unexp_num1 = 666
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = []
@ -1840,6 +1845,39 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
@pytest.mark.asyncio
async def test_state_manager_lock_warning_threshold_contend(
    state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
):
    """Check that holding the lock longer than the warning threshold emits warnings.

    Args:
        state_manager_redis: A state manager instance.
        token: A token.
        substate_token_redis: A token + substate name for looking up in state manager.
        mocker: Pytest mocker object.
    """
    warn_mock = mocker.patch("reflex.utils.console.warn")

    state_manager_redis.lock_expiration = LOCK_EXPIRATION
    state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD

    order = []

    async def _coro_blocker():
        # Hold the lock for longer than the warning threshold.
        async with state_manager_redis.modify_state(substate_token_redis):
            order.append("blocker")
            await asyncio.sleep(LOCK_WARN_SLEEP)

    blocker_task = asyncio.create_task(_coro_blocker())
    await blocker_task

    warn_mock.assert_called()
    assert warn_mock.call_count == 7
class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""
@ -3253,12 +3291,42 @@ async def test_setvar_async_setter():
@pytest.mark.parametrize(
"expiration_kwargs, expected_values",
[
({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
(
{"redis_lock_expiration": 20000},
(
20000,
constants.Expiration.TOKEN,
constants.Expiration.LOCK_WARNING_THRESHOLD,
),
),
(
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
(50000, 5600),
(50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
),
(
{"redis_token_expiration": 7600},
(
constants.Expiration.LOCK,
7600,
constants.Expiration.LOCK_WARNING_THRESHOLD,
),
),
(
{"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
(50000, constants.Expiration.TOKEN, 1500),
),
(
{"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
(constants.Expiration.LOCK, 5600, 3000),
),
(
{
"redis_lock_expiration": 50000,
"redis_token_expiration": 5600,
"redis_lock_warning_threshold": 2000,
},
(50000, 5600, 2000),
),
({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
],
)
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
@ -3288,6 +3356,44 @@ config = rx.Config(
state_manager = StateManager.create(state=State)
assert state_manager.lock_expiration == expected_values[0] # type: ignore
assert state_manager.token_expiration == expected_values[1] # type: ignore
assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize(
    "redis_lock_expiration, redis_lock_warning_threshold",
    [
        (10000, 10000),
        (20000, 30000),
    ],
)
def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
    tmp_path, redis_lock_expiration, redis_lock_warning_threshold
):
    """Test that a warning threshold not below the lock expiration is rejected.

    Args:
        tmp_path: Pytest temporary directory fixture.
        redis_lock_expiration: The lock expiration written to the config.
        redis_lock_warning_threshold: The lock warning threshold written to the config.
    """
    proj_root = tmp_path / "project1"
    proj_root.mkdir()

    config_string = f"""
import reflex as rx
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
state_manager_mode="redis",
redis_lock_expiration = {redis_lock_expiration},
redis_lock_warning_threshold = {redis_lock_warning_threshold},
)
"""
    (proj_root / "rxconfig.py").write_text(dedent(config_string))

    with chdir(proj_root):
        try:
            # reload config for each parameter to avoid stale values
            reflex.config.get_config(reload=True)
            from reflex.state import State, StateManager

            with pytest.raises(InvalidLockWarningThresholdError):
                StateManager.create(state=State)
        finally:
            # Always drop the loaded config module, even when the assertion
            # fails, so the next parametrized case does not see a stale one.
            sys.modules.pop(constants.Config.MODULE, None)
class MixinState(State, mixin=True):