From c387f517b6b3818679d494f182cda73ea516cc07 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 12 Dec 2024 19:36:31 +0000 Subject: [PATCH] [ENG-4100]Throw warnings when Redis lock is held for more than the allowed threshold (#4522) * Throw warnings when Redis lock is held for more than the allowed threshold * initial tests * fix tests and address comments * fix tests fr, and use pydantic validators * darglint fix * increase lock expiration in tests to 2500 * remove print statement --------- Co-authored-by: Khaleel Al-Adhami --- reflex/config.py | 3 + reflex/constants/config.py | 2 + reflex/state.py | 53 +++++++++++++++++ reflex/utils/console.py | 66 +++++++++++++++++++-- reflex/utils/exceptions.py | 4 ++ tests/units/test_state.py | 114 +++++++++++++++++++++++++++++++++++-- 6 files changed, 232 insertions(+), 10 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index ae2c0ea0e..bbea6a5d0 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -684,6 +684,9 @@ class Config(Base): # Maximum expiration lock time for redis state manager redis_lock_expiration: int = constants.Expiration.LOCK + # Maximum lock time before warning for redis state manager. + redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD + # Token expiration time for redis state manager redis_token_expiration: int = constants.Expiration.TOKEN diff --git a/reflex/constants/config.py b/reflex/constants/config.py index 970e67844..7425fd864 100644 --- a/reflex/constants/config.py +++ b/reflex/constants/config.py @@ -29,6 +29,8 @@ class Expiration(SimpleNamespace): LOCK = 10000 # The PING timeout PING = 120 + # The maximum time in milliseconds to hold a lock before throwing a warning. + LOCK_WARNING_THRESHOLD = 1000 class GitIgnore(SimpleNamespace): diff --git a/reflex/state.py b/reflex/state.py index 3e606bf57..434ee3921 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -71,6 +71,11 @@ try: except ModuleNotFoundError: BaseModelV1 = BaseModelV2 +try: + from pydantic.v1 import validator +except ModuleNotFoundError: + from pydantic import validator + import wrapt from redis.asyncio import Redis from redis.exceptions import ResponseError @@ -94,6 +99,7 @@ from reflex.utils.exceptions import ( DynamicRouteArgShadowsStateVar, EventHandlerShadowsBuiltInStateMethod, ImmutableStateError, + InvalidLockWarningThresholdError, InvalidStateManagerMode, LockExpiredError, ReflexRuntimeError, @@ -2834,6 +2840,7 @@ class StateManager(Base, ABC): redis=redis, token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, + lock_warning_threshold=config.redis_lock_warning_threshold, ) raise InvalidStateManagerMode( f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" @@ -3203,6 +3210,15 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +def _default_lock_warning_threshold() -> int: + """Get the default lock warning threshold. + + Returns: + The default lock warning threshold. + """ + return get_config().redis_lock_warning_threshold + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3215,6 +3231,11 @@ class StateManagerRedis(StateManager): # The maximum time to hold a lock (ms). lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration) + # The maximum time to hold a lock (ms) before warning. + lock_warning_threshold: int = pydantic.Field( + default_factory=_default_lock_warning_threshold + ) + # The keyspace subscription string when redis is waiting for lock to be released _redis_notify_keyspace_events: str = ( "K" # Enable keyspace notifications (target a particular key) @@ -3402,6 +3423,17 @@ class StateManagerRedis(StateManager): f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.event(background=True)` decorator for long-running tasks." ) + elif lock_id is not None: + time_taken = self.lock_expiration / 1000 - ( + await self.redis.ttl(self._lock_key(token)) + ) + if time_taken > self.lock_warning_threshold / 1000: + console.warn( + f"Lock for token {token} was held too long {time_taken=}s, " + f"use `@rx.event(background=True)` decorator for long-running tasks.", + dedupe=True, + ) + client_token, substate_name = _split_substate_key(token) # If the substate name on the token doesn't match the instance name, it cannot have a parent. if state.parent_state is not None and state.get_full_name() != substate_name: @@ -3451,6 +3483,27 @@ class StateManagerRedis(StateManager): yield state await self.set_state(token, state, lock_id) + @validator("lock_warning_threshold") + @classmethod + def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values): + """Validate the lock warning threshold. + + Args: + lock_warning_threshold: The lock warning threshold. + values: The validated attributes. + + Returns: + The lock warning threshold. + + Raises: + InvalidLockWarningThresholdError: If the lock warning threshold is invalid. + """ + if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]): + raise InvalidLockWarningThresholdError( + f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." + ) + return lock_warning_threshold + @staticmethod def _lock_key(token: str) -> bytes: """Get the redis key for a token's lock. diff --git a/reflex/utils/console.py b/reflex/utils/console.py index b3ba7163d..be545140a 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set() # Info messages which have been printed. _EMITTED_INFO = set() +# Warnings which have been printed. +_EMIITED_WARNINGS = set() + +# Errors which have been printed. +_EMITTED_ERRORS = set() + +# Success messages which have been printed. +_EMITTED_SUCCESS = set() + +# Debug messages which have been printed. +_EMITTED_DEBUG = set() + +# Logs which have been printed. +_EMITTED_LOGS = set() + +# Prints which have been printed. +_EMITTED_PRINTS = set() + def set_log_level(log_level: LogLevel): """Set the log level. @@ -55,25 +73,37 @@ def is_debug() -> bool: return _LOG_LEVEL <= LogLevel.DEBUG -def print(msg: str, **kwargs): +def print(msg: str, dedupe: bool = False, **kwargs): """Print a message. Args: msg: The message to print. + dedupe: If True, suppress multiple console logs of print message. kwargs: Keyword arguments to pass to the print function. """ + if dedupe: + if msg in _EMITTED_PRINTS: + return + else: + _EMITTED_PRINTS.add(msg) _console.print(msg, **kwargs) -def debug(msg: str, **kwargs): +def debug(msg: str, dedupe: bool = False, **kwargs): """Print a debug message. Args: msg: The debug message. + dedupe: If True, suppress multiple console logs of debug message. kwargs: Keyword arguments to pass to the print function. """ if is_debug(): msg_ = f"[purple]Debug: {msg}[/purple]" + if dedupe: + if msg_ in _EMITTED_DEBUG: + return + else: + _EMITTED_DEBUG.add(msg_) if progress := kwargs.pop("progress", None): progress.console.print(msg_, **kwargs) else: @@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs): print(f"[cyan]Info: {msg}[/cyan]", **kwargs) -def success(msg: str, **kwargs): +def success(msg: str, dedupe: bool = False, **kwargs): """Print a success message. Args: msg: The success message. + dedupe: If True, suppress multiple console logs of success message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.INFO: + if dedupe: + if msg in _EMITTED_SUCCESS: + return + else: + _EMITTED_SUCCESS.add(msg) print(f"[green]Success: {msg}[/green]", **kwargs) -def log(msg: str, **kwargs): +def log(msg: str, dedupe: bool = False, **kwargs): """Takes a string and logs it to the console. Args: msg: The message to log. + dedupe: If True, suppress multiple console logs of log message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.INFO: + if dedupe: + if msg in _EMITTED_LOGS: + return + else: + _EMITTED_LOGS.add(msg) _console.log(msg, **kwargs) @@ -129,14 +171,20 @@ def rule(title: str, **kwargs): _console.rule(title, **kwargs) -def warn(msg: str, **kwargs): +def warn(msg: str, dedupe: bool = False, **kwargs): """Print a warning message. Args: msg: The warning message. + dedupe: If True, suppress multiple console logs of warning message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.WARNING: + if dedupe: + if msg in _EMIITED_WARNINGS: + return + else: + _EMIITED_WARNINGS.add(msg) print(f"[orange1]Warning: {msg}[/orange1]", **kwargs) @@ -169,14 +217,20 @@ def deprecate( _EMITTED_DEPRECATION_WARNINGS.add(feature_name) -def error(msg: str, **kwargs): +def error(msg: str, dedupe: bool = False, **kwargs): """Print an error message. Args: msg: The error message. + dedupe: If True, suppress multiple console logs of error message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.ERROR: + if dedupe: + if msg in _EMITTED_ERRORS: + return + else: + _EMITTED_ERRORS.add(msg) print(f"[red]{msg}[/red]", **kwargs) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 6c378e159..ae5ec0168 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn: " Please install it through your system package manager." + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "") ) + + +class InvalidLockWarningThresholdError(ReflexError): + """Raised when an invalid lock warning threshold is provided.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 9e952e10f..912d72f4f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -56,6 +56,7 @@ from reflex.state import ( from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import ( + InvalidLockWarningThresholdError, ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, @@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState from .states import GenState CI = bool(os.environ.get("CI", False)) -LOCK_EXPIRATION = 2000 if CI else 300 +LOCK_EXPIRATION = 2500 if CI else 300 +LOCK_WARNING_THRESHOLD = 1000 if CI else 100 +LOCK_WARN_SLEEP = 1.5 if CI else 0.15 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 @@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire( substate_token_redis: A token + substate name for looking up in state manager. """ state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD async with state_manager_redis.modify_state(substate_token_redis): await asyncio.sleep(0.01) @@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend( unexp_num1 = 666 state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD order = [] @@ -1840,6 +1845,39 @@ async def test_state_manager_lock_expire_contend( assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.mark.asyncio +async def test_state_manager_lock_warning_threshold_contend( + state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker +): + """Test that the state manager triggers a warning when lock contention exceeds the warning threshold. + + Args: + state_manager_redis: A state manager instance. + token: A token. + substate_token_redis: A token + substate name for looking up in state manager. + mocker: Pytest mocker object. + """ + console_warn = mocker.patch("reflex.utils.console.warn") + + state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + + order = [] + + async def _coro_blocker(): + async with state_manager_redis.modify_state(substate_token_redis): + order.append("blocker") + await asyncio.sleep(LOCK_WARN_SLEEP) + + tasks = [ + asyncio.create_task(_coro_blocker()), + ] + + await tasks[0] + console_warn.assert_called() + assert console_warn.call_count == 7 + + class CopyingAsyncMock(AsyncMock): """An AsyncMock, but deepcopy the args and kwargs first.""" @@ -3253,12 +3291,42 @@ async def test_setvar_async_setter(): @pytest.mark.parametrize( "expiration_kwargs, expected_values", [ - ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)), + ( + {"redis_lock_expiration": 20000}, + ( + 20000, + constants.Expiration.TOKEN, + constants.Expiration.LOCK_WARNING_THRESHOLD, + ), + ), ( {"redis_lock_expiration": 50000, "redis_token_expiration": 5600}, - (50000, 5600), + (50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD), + ), + ( + {"redis_token_expiration": 7600}, + ( + constants.Expiration.LOCK, + 7600, + constants.Expiration.LOCK_WARNING_THRESHOLD, + ), + ), + ( + {"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500}, + (50000, constants.Expiration.TOKEN, 1500), + ), + ( + {"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000}, + (constants.Expiration.LOCK, 5600, 3000), + ), + ( + { + "redis_lock_expiration": 50000, + "redis_token_expiration": 5600, + "redis_lock_warning_threshold": 2000, + }, + (50000, 5600, 2000), ), - ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)), ], ) def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values): @@ -3288,6 +3356,44 @@ config = rx.Config( state_manager = StateManager.create(state=State) assert state_manager.lock_expiration == expected_values[0] # type: ignore assert state_manager.token_expiration == expected_values[1] # type: ignore + assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore + + +@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") +@pytest.mark.parametrize( + "redis_lock_expiration, redis_lock_warning_threshold", + [ + (10000, 10000), + (20000, 30000), + ], +) +def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( + tmp_path, redis_lock_expiration, redis_lock_warning_threshold +): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_string = f""" +import reflex as rx +config = rx.Config( + app_name="project1", + redis_url="redis://localhost:6379", + state_manager_mode="redis", + redis_lock_expiration = {redis_lock_expiration}, + redis_lock_warning_threshold = {redis_lock_warning_threshold}, +) + """ + + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + + with chdir(proj_root): + # reload config for each parameter to avoid stale values + reflex.config.get_config(reload=True) + from reflex.state import State, StateManager + + with pytest.raises(InvalidLockWarningThresholdError): + StateManager.create(state=State) + del sys.modules[constants.Config.MODULE] class MixinState(State, mixin=True):