properly ignore non-lock events from redis pubsub, keep timeout

This commit is contained in:
Benedikt Bartscher 2024-12-01 17:16:04 +01:00
parent dec1ab108c
commit 368a2ce55f
No known key found for this signature in database

View File

@ -11,6 +11,7 @@ import inspect
import json import json
import pickle import pickle
import sys import sys
import time
import typing import typing
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -39,6 +40,7 @@ from typing import (
get_type_hints, get_type_hints,
) )
from redis.asyncio.client import PubSub
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self from typing_extensions import Self
@ -3182,6 +3184,14 @@ class StateManagerRedis(StateManager):
"e" # For evicted events (i.e. maxmemory exceeded) "e" # For evicted events (i.e. maxmemory exceeded)
) )
# These events indicate that a lock is no longer held
_redis_keyspace_lock_release_events: Set[bytes] = {
b"del",
b"expire",
b"expired",
b"evicted",
}
async def _get_parent_state( async def _get_parent_state(
self, token: str, state: BaseState | None = None self, token: str, state: BaseState | None = None
) -> BaseState | None: ) -> BaseState | None:
@ -3433,6 +3443,35 @@ class StateManagerRedis(StateManager):
nx=True, # only set if it doesn't exist nx=True, # only set if it doesn't exist
) )
async def _get_pubsub_message(
self, pubsub: PubSub, timeout: float | None = None
) -> None:
"""Get lock release events from the pubsub.
Args:
pubsub: The pubsub to get a message from.
timeout: Remaining time to wait for a message.
Returns:
The message.
"""
if timeout is None:
timeout = self.lock_expiration / 1000.0
started = time.time()
message = await pubsub.get_message(
ignore_subscribe_messages=True,
timeout=timeout,
)
if (
message is None
or message["data"] not in self._redis_keyspace_lock_release_events
):
remaining = timeout - (time.time() - started)
if remaining <= 0:
return
await self._get_pubsub_message(pubsub, timeout=remaining)
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
"""Wait for a redis lock to be released via pubsub. """Wait for a redis lock to be released via pubsub.
@ -3464,10 +3503,7 @@ class StateManagerRedis(StateManager):
if await self._try_get_lock(lock_key, lock_id): if await self._try_get_lock(lock_key, lock_id):
return return
# wait for lock events # wait for lock events
_ = await pubsub.get_message( await self._get_pubsub_message(pubsub)
ignore_subscribe_messages=True,
timeout=self.lock_expiration / 1000.0,
)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def _lock(self, token: str): async def _lock(self, token: str):