fix tests fr, and use pydantic validators

This commit is contained in:
Elijah 2024-12-12 10:58:33 +00:00
parent 1f2eafb4e7
commit 732fdb8366
3 changed files with 31 additions and 19 deletions

View File

@ -71,6 +71,11 @@ try:
except ModuleNotFoundError:
BaseModelV1 = BaseModelV2
try:
from pydantic.v1 import validator
except ModuleNotFoundError:
from pydantic import validator
import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
@ -2835,6 +2840,7 @@ class StateManager(Base, ABC):
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
lock_warning_threshold=config.redis_lock_warning_threshold,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@ -3210,22 +3216,7 @@ def _default_lock_warning_threshold() -> int:
Returns:
The default lock warning threshold.
"""
lock_warning_threshold = get_config().redis_lock_warning_threshold
_validate_lock_warning_threshold(lock_warning_threshold, _default_lock_expiration())
return lock_warning_threshold
def _validate_lock_warning_threshold(lock_warning_threshold: int, lock_expiration: int):
"""Validate the lock warning threshold.
Args:
lock_warning_threshold: The lock warning threshold.
lock_expiration: The lock expiration time.
"""
if lock_warning_threshold >= lock_expiration:
raise InvalidLockWarningThresholdError(
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
)
return get_config().redis_lock_warning_threshold
class StateManagerRedis(StateManager):
@ -3436,6 +3427,9 @@ class StateManagerRedis(StateManager):
time_taken = self.lock_expiration / 1000 - (
await self.redis.ttl(self._lock_key(token))
)
print(
f"Time taken to set state: {time_taken} - {self.lock_warning_threshold}"
)
if time_taken > self.lock_warning_threshold / 1000:
console.warn(
f"Lock for token {token} was held too long {time_taken=}s, "
@ -3492,6 +3486,21 @@ class StateManagerRedis(StateManager):
yield state
await self.set_state(token, state, lock_id)
@validator("lock_warning_threshold")
@classmethod
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
"""Validate the lock warning threshold.
Args:
lock_warning_threshold: The lock warning threshold.
values: The validated attributes.
"""
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
raise InvalidLockWarningThresholdError(
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
)
return lock_warning_threshold
@staticmethod
def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock.

View File

@ -185,5 +185,5 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
)
class InvalidLockWarningThresholdError(ReflexError, ValueError):
class InvalidLockWarningThresholdError(ReflexError):
"""Raised when an invalid lock warning threshold is provided."""

View File

@ -69,8 +69,8 @@ from .states import GenState
CI = bool(os.environ.get("CI", False))
LOCK_EXPIRATION = 2000 if CI else 300
LOCK_WARNING_THRESHOLD = 1000 if CI else 200
LOCK_WARN_SLEEP = 1.5 if CI else 0.25
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
@ -1790,6 +1790,7 @@ async def test_state_manager_lock_expire(
substate_token_redis: A token + substate name for looking up in state manager.
"""
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(0.01)
@ -1814,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
unexp_num1 = 666
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = []
@ -3391,6 +3393,7 @@ config = rx.Config(
with pytest.raises(InvalidLockWarningThresholdError):
StateManager.create(state=State)
del sys.modules[constants.Config.MODULE]
class MixinState(State, mixin=True):