diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 9e952e10f..d79a070db 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -56,6 +56,7 @@ from reflex.state import ( from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import ( + InvalidLockWarningThresholdError, ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, @@ -68,6 +69,8 @@ from .states import GenState CI = bool(os.environ.get("CI", False)) LOCK_EXPIRATION = 2000 if CI else 300 +LOCK_WARNING_THRESHOLD = 1000 if CI else 200 +LOCK_WARN_SLEEP = 1.5 if CI else 0.25 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 @@ -1840,6 +1843,39 @@ async def test_state_manager_lock_expire_contend( assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.mark.asyncio +async def test_state_manager_lock_warning_threshold_contend( + state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker +): + """Test that the state manager triggers a warning when lock contention exceeds the warning threshold. + + Args: + state_manager_redis: A state manager instance. + token: A token. + substate_token_redis: A token + substate name for looking up in state manager. + mocker: Pytest mocker object. + """ + console_warn = mocker.patch("reflex.utils.console.warn") + + state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + + order = [] + + async def _coro_blocker(): + async with state_manager_redis.modify_state(substate_token_redis): + order.append("blocker") + await asyncio.sleep(LOCK_WARN_SLEEP) + + tasks = [ + asyncio.create_task(_coro_blocker()), + ] + + await tasks[0] + console_warn.assert_called() + assert console_warn.call_count == 7 + + class CopyingAsyncMock(AsyncMock): """An AsyncMock, but deepcopy the args and kwargs first.""" @@ -3259,6 +3295,10 @@ async def test_setvar_async_setter(): (50000, 5600), ), ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)), + ( + {"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 2000}, + (50000, 2000), + ), ], ) def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values): @@ -3290,6 +3330,42 @@ config = rx.Config( assert state_manager.token_expiration == expected_values[1] # type: ignore +@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") +@pytest.mark.parametrize( + "redis_lock_expiration, redis_lock_warning_threshold", + [ + (10000, 10000), + (20000, 30000), + ], +) +def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( + tmp_path, redis_lock_expiration, redis_lock_warning_threshold +): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_string = f""" +import reflex as rx +config = rx.Config( + app_name="project1", + redis_url="redis://localhost:6379", + state_manager_mode="redis", + redis_lock_expiration = {redis_lock_expiration}, + redis_lock_warning_threshold = {redis_lock_warning_threshold}, +) + """ + + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + + with chdir(proj_root): + # reload config for each parameter to avoid stale values + reflex.config.get_config(reload=True) + from reflex.state import State, StateManager + + with pytest.raises(InvalidLockWarningThresholdError): + StateManager.create(state=State) + + class MixinState(State, mixin=True): """A mixin state for testing."""