fix tests fr, and use pydantic validators
This commit is contained in:
parent
1f2eafb4e7
commit
732fdb8366
@ -71,6 +71,11 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
BaseModelV1 = BaseModelV2
|
BaseModelV1 = BaseModelV2
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pydantic.v1 import validator
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
import wrapt
|
import wrapt
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
from redis.exceptions import ResponseError
|
from redis.exceptions import ResponseError
|
||||||
@ -2835,6 +2840,7 @@ class StateManager(Base, ABC):
|
|||||||
redis=redis,
|
redis=redis,
|
||||||
token_expiration=config.redis_token_expiration,
|
token_expiration=config.redis_token_expiration,
|
||||||
lock_expiration=config.redis_lock_expiration,
|
lock_expiration=config.redis_lock_expiration,
|
||||||
|
lock_warning_threshold=config.redis_lock_warning_threshold,
|
||||||
)
|
)
|
||||||
raise InvalidStateManagerMode(
|
raise InvalidStateManagerMode(
|
||||||
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
||||||
@ -3210,22 +3216,7 @@ def _default_lock_warning_threshold() -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
The default lock warning threshold.
|
The default lock warning threshold.
|
||||||
"""
|
"""
|
||||||
lock_warning_threshold = get_config().redis_lock_warning_threshold
|
return get_config().redis_lock_warning_threshold
|
||||||
_validate_lock_warning_threshold(lock_warning_threshold, _default_lock_expiration())
|
|
||||||
return lock_warning_threshold
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_lock_warning_threshold(lock_warning_threshold: int, lock_expiration: int):
|
|
||||||
"""Validate the lock warning threshold.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
lock_warning_threshold: The lock warning threshold.
|
|
||||||
lock_expiration: The lock expiration time.
|
|
||||||
"""
|
|
||||||
if lock_warning_threshold >= lock_expiration:
|
|
||||||
raise InvalidLockWarningThresholdError(
|
|
||||||
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StateManagerRedis(StateManager):
|
class StateManagerRedis(StateManager):
|
||||||
@ -3436,6 +3427,9 @@ class StateManagerRedis(StateManager):
|
|||||||
time_taken = self.lock_expiration / 1000 - (
|
time_taken = self.lock_expiration / 1000 - (
|
||||||
await self.redis.ttl(self._lock_key(token))
|
await self.redis.ttl(self._lock_key(token))
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
f"Time taken to set state: {time_taken} - {self.lock_warning_threshold}"
|
||||||
|
)
|
||||||
if time_taken > self.lock_warning_threshold / 1000:
|
if time_taken > self.lock_warning_threshold / 1000:
|
||||||
console.warn(
|
console.warn(
|
||||||
f"Lock for token {token} was held too long {time_taken=}s, "
|
f"Lock for token {token} was held too long {time_taken=}s, "
|
||||||
@ -3492,6 +3486,21 @@ class StateManagerRedis(StateManager):
|
|||||||
yield state
|
yield state
|
||||||
await self.set_state(token, state, lock_id)
|
await self.set_state(token, state, lock_id)
|
||||||
|
|
||||||
|
@validator("lock_warning_threshold")
|
||||||
|
@classmethod
|
||||||
|
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
|
||||||
|
"""Validate the lock warning threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lock_warning_threshold: The lock warning threshold.
|
||||||
|
values: The validated attributes.
|
||||||
|
"""
|
||||||
|
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
|
||||||
|
raise InvalidLockWarningThresholdError(
|
||||||
|
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
||||||
|
)
|
||||||
|
return lock_warning_threshold
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _lock_key(token: str) -> bytes:
|
def _lock_key(token: str) -> bytes:
|
||||||
"""Get the redis key for a token's lock.
|
"""Get the redis key for a token's lock.
|
||||||
|
@ -185,5 +185,5 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidLockWarningThresholdError(ReflexError, ValueError):
|
class InvalidLockWarningThresholdError(ReflexError):
|
||||||
"""Raised when an invalid lock warning threshold is provided."""
|
"""Raised when an invalid lock warning threshold is provided."""
|
||||||
|
@ -69,8 +69,8 @@ from .states import GenState
|
|||||||
|
|
||||||
CI = bool(os.environ.get("CI", False))
|
CI = bool(os.environ.get("CI", False))
|
||||||
LOCK_EXPIRATION = 2000 if CI else 300
|
LOCK_EXPIRATION = 2000 if CI else 300
|
||||||
LOCK_WARNING_THRESHOLD = 1000 if CI else 200
|
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
|
||||||
LOCK_WARN_SLEEP = 1.5 if CI else 0.25
|
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
|
||||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
||||||
|
|
||||||
|
|
||||||
@ -1790,6 +1790,7 @@ async def test_state_manager_lock_expire(
|
|||||||
substate_token_redis: A token + substate name for looking up in state manager.
|
substate_token_redis: A token + substate name for looking up in state manager.
|
||||||
"""
|
"""
|
||||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||||
|
|
||||||
async with state_manager_redis.modify_state(substate_token_redis):
|
async with state_manager_redis.modify_state(substate_token_redis):
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
@ -1814,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
|
|||||||
unexp_num1 = 666
|
unexp_num1 = 666
|
||||||
|
|
||||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||||
|
|
||||||
order = []
|
order = []
|
||||||
|
|
||||||
@ -3391,6 +3393,7 @@ config = rx.Config(
|
|||||||
|
|
||||||
with pytest.raises(InvalidLockWarningThresholdError):
|
with pytest.raises(InvalidLockWarningThresholdError):
|
||||||
StateManager.create(state=State)
|
StateManager.create(state=State)
|
||||||
|
del sys.modules[constants.Config.MODULE]
|
||||||
|
|
||||||
|
|
||||||
class MixinState(State, mixin=True):
|
class MixinState(State, mixin=True):
|
||||||
|
Loading…
Reference in New Issue
Block a user