fix tests fr, and use pydantic validators

This commit is contained in:
Elijah 2024-12-12 10:58:33 +00:00
parent 1f2eafb4e7
commit 732fdb8366
3 changed files with 31 additions and 19 deletions

View File

@ -71,6 +71,11 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
BaseModelV1 = BaseModelV2 BaseModelV1 = BaseModelV2
try:
from pydantic.v1 import validator
except ModuleNotFoundError:
from pydantic import validator
import wrapt import wrapt
from redis.asyncio import Redis from redis.asyncio import Redis
from redis.exceptions import ResponseError from redis.exceptions import ResponseError
@ -2835,6 +2840,7 @@ class StateManager(Base, ABC):
redis=redis, redis=redis,
token_expiration=config.redis_token_expiration, token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration, lock_expiration=config.redis_lock_expiration,
lock_warning_threshold=config.redis_lock_warning_threshold,
) )
raise InvalidStateManagerMode( raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@ -3210,22 +3216,7 @@ def _default_lock_warning_threshold() -> int:
Returns: Returns:
The default lock warning threshold. The default lock warning threshold.
""" """
lock_warning_threshold = get_config().redis_lock_warning_threshold return get_config().redis_lock_warning_threshold
_validate_lock_warning_threshold(lock_warning_threshold, _default_lock_expiration())
return lock_warning_threshold
def _validate_lock_warning_threshold(lock_warning_threshold: int, lock_expiration: int):
"""Validate the lock warning threshold.
Args:
lock_warning_threshold: The lock warning threshold.
lock_expiration: The lock expiration time.
"""
if lock_warning_threshold >= lock_expiration:
raise InvalidLockWarningThresholdError(
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
)
class StateManagerRedis(StateManager): class StateManagerRedis(StateManager):
@ -3436,6 +3427,9 @@ class StateManagerRedis(StateManager):
time_taken = self.lock_expiration / 1000 - ( time_taken = self.lock_expiration / 1000 - (
await self.redis.ttl(self._lock_key(token)) await self.redis.ttl(self._lock_key(token))
) )
print(
f"Time taken to set state: {time_taken} - {self.lock_warning_threshold}"
)
if time_taken > self.lock_warning_threshold / 1000: if time_taken > self.lock_warning_threshold / 1000:
console.warn( console.warn(
f"Lock for token {token} was held too long {time_taken=}s, " f"Lock for token {token} was held too long {time_taken=}s, "
@ -3492,6 +3486,21 @@ class StateManagerRedis(StateManager):
yield state yield state
await self.set_state(token, state, lock_id) await self.set_state(token, state, lock_id)
@validator("lock_warning_threshold")
@classmethod
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
"""Validate the lock warning threshold.
Args:
lock_warning_threshold: The lock warning threshold.
values: The validated attributes.
"""
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
raise InvalidLockWarningThresholdError(
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
)
return lock_warning_threshold
@staticmethod @staticmethod
def _lock_key(token: str) -> bytes: def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock. """Get the redis key for a token's lock.

View File

@ -185,5 +185,5 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
) )
class InvalidLockWarningThresholdError(ReflexError, ValueError): class InvalidLockWarningThresholdError(ReflexError):
"""Raised when an invalid lock warning threshold is provided.""" """Raised when an invalid lock warning threshold is provided."""

View File

@ -69,8 +69,8 @@ from .states import GenState
CI = bool(os.environ.get("CI", False)) CI = bool(os.environ.get("CI", False))
LOCK_EXPIRATION = 2000 if CI else 300 LOCK_EXPIRATION = 2000 if CI else 300
LOCK_WARNING_THRESHOLD = 1000 if CI else 200 LOCK_WARNING_THRESHOLD = 1000 if CI else 100
LOCK_WARN_SLEEP = 1.5 if CI else 0.25 LOCK_WARN_SLEEP = 1.5 if CI else 0.15
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
@ -1790,6 +1790,7 @@ async def test_state_manager_lock_expire(
substate_token_redis: A token + substate name for looking up in state manager. substate_token_redis: A token + substate name for looking up in state manager.
""" """
state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
async with state_manager_redis.modify_state(substate_token_redis): async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
@ -1814,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
unexp_num1 = 666 unexp_num1 = 666
state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = [] order = []
@ -3391,6 +3393,7 @@ config = rx.Config(
with pytest.raises(InvalidLockWarningThresholdError): with pytest.raises(InvalidLockWarningThresholdError):
StateManager.create(state=State) StateManager.create(state=State)
del sys.modules[constants.Config.MODULE]
class MixinState(State, mixin=True): class MixinState(State, mixin=True):