diff --git a/reflex/state.py b/reflex/state.py index 83c933885..1eb33b122 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -71,6 +71,11 @@ try: except ModuleNotFoundError: BaseModelV1 = BaseModelV2 +try: + from pydantic.v1 import validator +except ModuleNotFoundError: + from pydantic import validator + import wrapt from redis.asyncio import Redis from redis.exceptions import ResponseError @@ -2835,6 +2840,7 @@ class StateManager(Base, ABC): redis=redis, token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, + lock_warning_threshold=config.redis_lock_warning_threshold, ) raise InvalidStateManagerMode( f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" @@ -3210,22 +3216,7 @@ def _default_lock_warning_threshold() -> int: Returns: The default lock warning threshold. """ - lock_warning_threshold = get_config().redis_lock_warning_threshold - _validate_lock_warning_threshold(lock_warning_threshold, _default_lock_expiration()) - return lock_warning_threshold - - -def _validate_lock_warning_threshold(lock_warning_threshold: int, lock_expiration: int): - """Validate the lock warning threshold. - - Args: - lock_warning_threshold: The lock warning threshold. - lock_expiration: The lock expiration time. - """ - if lock_warning_threshold >= lock_expiration: - raise InvalidLockWarningThresholdError( - f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." - ) + return get_config().redis_lock_warning_threshold class StateManagerRedis(StateManager): @@ -3436,6 +3427,9 @@ class StateManagerRedis(StateManager): time_taken = self.lock_expiration / 1000 - ( await self.redis.ttl(self._lock_key(token)) ) + print( + f"Time taken to set state: {time_taken} - {self.lock_warning_threshold}" + ) if time_taken > self.lock_warning_threshold / 1000: console.warn( f"Lock for token {token} was held too long {time_taken=}s, " @@ -3492,6 +3486,21 @@ class StateManagerRedis(StateManager): yield state await self.set_state(token, state, lock_id) + @validator("lock_warning_threshold") + @classmethod + def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values): + """Validate the lock warning threshold. + + Args: + lock_warning_threshold: The lock warning threshold. + values: The validated attributes. + """ + if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]): + raise InvalidLockWarningThresholdError( + f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." + ) + return lock_warning_threshold + @staticmethod def _lock_key(token: str) -> bytes: """Get the redis key for a token's lock. diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 8ec2eeb73..ae5ec0168 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -185,5 +185,5 @@ def raise_system_package_missing_error(package: str) -> NoReturn: ) -class InvalidLockWarningThresholdError(ReflexError, ValueError): +class InvalidLockWarningThresholdError(ReflexError): """Raised when an invalid lock warning threshold is provided.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 68bafe4df..bfa9b80bc 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -69,8 +69,8 @@ from .states import GenState CI = bool(os.environ.get("CI", False)) LOCK_EXPIRATION = 2000 if CI else 300 -LOCK_WARNING_THRESHOLD = 1000 if CI else 200 -LOCK_WARN_SLEEP = 1.5 if CI else 0.25 +LOCK_WARNING_THRESHOLD = 1000 if CI else 100 +LOCK_WARN_SLEEP = 1.5 if CI else 0.15 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 @@ -1790,6 +1790,7 @@ async def test_state_manager_lock_expire( substate_token_redis: A token + substate name for looking up in state manager. """ state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD async with state_manager_redis.modify_state(substate_token_redis): await asyncio.sleep(0.01) @@ -1814,6 +1815,7 @@ async def test_state_manager_lock_expire_contend( unexp_num1 = 666 state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD order = [] @@ -3391,6 +3393,7 @@ config = rx.Config( with pytest.raises(InvalidLockWarningThresholdError): StateManager.create(state=State) + del sys.modules[constants.Config.MODULE] class MixinState(State, mixin=True):