fix tests fr, and use pydantic validators
This commit is contained in:
parent
1f2eafb4e7
commit
732fdb8366
@ -71,6 +71,11 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
BaseModelV1 = BaseModelV2
|
||||
|
||||
try:
|
||||
from pydantic.v1 import validator
|
||||
except ModuleNotFoundError:
|
||||
from pydantic import validator
|
||||
|
||||
import wrapt
|
||||
from redis.asyncio import Redis
|
||||
from redis.exceptions import ResponseError
|
||||
@ -2835,6 +2840,7 @@ class StateManager(Base, ABC):
|
||||
redis=redis,
|
||||
token_expiration=config.redis_token_expiration,
|
||||
lock_expiration=config.redis_lock_expiration,
|
||||
lock_warning_threshold=config.redis_lock_warning_threshold,
|
||||
)
|
||||
raise InvalidStateManagerMode(
|
||||
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
||||
@ -3210,22 +3216,7 @@ def _default_lock_warning_threshold() -> int:
|
||||
Returns:
|
||||
The default lock warning threshold.
|
||||
"""
|
||||
lock_warning_threshold = get_config().redis_lock_warning_threshold
|
||||
_validate_lock_warning_threshold(lock_warning_threshold, _default_lock_expiration())
|
||||
return lock_warning_threshold
|
||||
|
||||
|
||||
def _validate_lock_warning_threshold(lock_warning_threshold: int, lock_expiration: int):
|
||||
"""Validate the lock warning threshold.
|
||||
|
||||
Args:
|
||||
lock_warning_threshold: The lock warning threshold.
|
||||
lock_expiration: The lock expiration time.
|
||||
"""
|
||||
if lock_warning_threshold >= lock_expiration:
|
||||
raise InvalidLockWarningThresholdError(
|
||||
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
||||
)
|
||||
return get_config().redis_lock_warning_threshold
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
@ -3436,6 +3427,9 @@ class StateManagerRedis(StateManager):
|
||||
time_taken = self.lock_expiration / 1000 - (
|
||||
await self.redis.ttl(self._lock_key(token))
|
||||
)
|
||||
print(
|
||||
f"Time taken to set state: {time_taken} - {self.lock_warning_threshold}"
|
||||
)
|
||||
if time_taken > self.lock_warning_threshold / 1000:
|
||||
console.warn(
|
||||
f"Lock for token {token} was held too long {time_taken=}s, "
|
||||
@ -3492,6 +3486,21 @@ class StateManagerRedis(StateManager):
|
||||
yield state
|
||||
await self.set_state(token, state, lock_id)
|
||||
|
||||
@validator("lock_warning_threshold")
|
||||
@classmethod
|
||||
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
|
||||
"""Validate the lock warning threshold.
|
||||
|
||||
Args:
|
||||
lock_warning_threshold: The lock warning threshold.
|
||||
values: The validated attributes.
|
||||
"""
|
||||
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
|
||||
raise InvalidLockWarningThresholdError(
|
||||
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
||||
)
|
||||
return lock_warning_threshold
|
||||
|
||||
@staticmethod
|
||||
def _lock_key(token: str) -> bytes:
|
||||
"""Get the redis key for a token's lock.
|
||||
|
@ -185,5 +185,5 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
|
||||
)
|
||||
|
||||
|
||||
class InvalidLockWarningThresholdError(ReflexError, ValueError):
|
||||
class InvalidLockWarningThresholdError(ReflexError):
|
||||
"""Raised when an invalid lock warning threshold is provided."""
|
||||
|
@ -69,8 +69,8 @@ from .states import GenState
|
||||
|
||||
CI = bool(os.environ.get("CI", False))
|
||||
LOCK_EXPIRATION = 2000 if CI else 300
|
||||
LOCK_WARNING_THRESHOLD = 1000 if CI else 200
|
||||
LOCK_WARN_SLEEP = 1.5 if CI else 0.25
|
||||
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
|
||||
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
|
||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
||||
|
||||
|
||||
@ -1790,6 +1790,7 @@ async def test_state_manager_lock_expire(
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
async with state_manager_redis.modify_state(substate_token_redis):
|
||||
await asyncio.sleep(0.01)
|
||||
@ -1814,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
|
||||
unexp_num1 = 666
|
||||
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
order = []
|
||||
|
||||
@ -3391,6 +3393,7 @@ config = rx.Config(
|
||||
|
||||
with pytest.raises(InvalidLockWarningThresholdError):
|
||||
StateManager.create(state=State)
|
||||
del sys.modules[constants.Config.MODULE]
|
||||
|
||||
|
||||
class MixinState(State, mixin=True):
|
||||
|
Loading…
Reference in New Issue
Block a user