properly ignore non-lock events from redis pubsub, keep timeout
This commit is contained in:
parent
dec1ab108c
commit
368a2ce55f
@ -11,6 +11,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import typing
|
import typing
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -39,6 +40,7 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from redis.asyncio.client import PubSub
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@ -3182,6 +3184,14 @@ class StateManagerRedis(StateManager):
|
|||||||
"e" # For evicted events (i.e. maxmemory exceeded)
|
"e" # For evicted events (i.e. maxmemory exceeded)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# These events indicate that a lock is no longer held
|
||||||
|
_redis_keyspace_lock_release_events: Set[bytes] = {
|
||||||
|
b"del",
|
||||||
|
b"expire",
|
||||||
|
b"expired",
|
||||||
|
b"evicted",
|
||||||
|
}
|
||||||
|
|
||||||
async def _get_parent_state(
|
async def _get_parent_state(
|
||||||
self, token: str, state: BaseState | None = None
|
self, token: str, state: BaseState | None = None
|
||||||
) -> BaseState | None:
|
) -> BaseState | None:
|
||||||
@ -3433,6 +3443,35 @@ class StateManagerRedis(StateManager):
|
|||||||
nx=True, # only set if it doesn't exist
|
nx=True, # only set if it doesn't exist
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _get_pubsub_message(
|
||||||
|
self, pubsub: PubSub, timeout: float | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Get lock release events from the pubsub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pubsub: The pubsub to get a message from.
|
||||||
|
timeout: Remaining time to wait for a message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The message.
|
||||||
|
"""
|
||||||
|
if timeout is None:
|
||||||
|
timeout = self.lock_expiration / 1000.0
|
||||||
|
|
||||||
|
started = time.time()
|
||||||
|
message = await pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
message is None
|
||||||
|
or message["data"] not in self._redis_keyspace_lock_release_events
|
||||||
|
):
|
||||||
|
remaining = timeout - (time.time() - started)
|
||||||
|
if remaining <= 0:
|
||||||
|
return
|
||||||
|
await self._get_pubsub_message(pubsub, timeout=remaining)
|
||||||
|
|
||||||
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
||||||
"""Wait for a redis lock to be released via pubsub.
|
"""Wait for a redis lock to be released via pubsub.
|
||||||
|
|
||||||
@ -3464,10 +3503,7 @@ class StateManagerRedis(StateManager):
|
|||||||
if await self._try_get_lock(lock_key, lock_id):
|
if await self._try_get_lock(lock_key, lock_id):
|
||||||
return
|
return
|
||||||
# wait for lock events
|
# wait for lock events
|
||||||
_ = await pubsub.get_message(
|
await self._get_pubsub_message(pubsub)
|
||||||
ignore_subscribe_messages=True,
|
|
||||||
timeout=self.lock_expiration / 1000.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _lock(self, token: str):
|
async def _lock(self, token: str):
|
||||||
|
Loading…
Reference in New Issue
Block a user