Simplify redis prefix for lock key

This commit is contained in:
Masen Furer 2024-11-21 17:14:09 -08:00
parent 181175fc7c
commit 2eedf15322
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4

View File

@ -3157,22 +3157,6 @@ def prefix_redis_token_str(token: str) -> str:
return f"{prefix}{token}" return f"{prefix}{token}"
@functools.lru_cache
def prefix_redis_token_bytes(token: bytes) -> bytes:
"""Prefix the token with the redis prefix.
Args:
token: The token to prefix.
Returns:
The prefixed token.
"""
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
if not prefix:
return token
return prefix.encode() + token
class StateManagerRedis(StateManager): class StateManagerRedis(StateManager):
"""A state manager that stores states in redis.""" """A state manager that stores states in redis."""
@ -3365,8 +3349,7 @@ class StateManagerRedis(StateManager):
# Check that we're holding the lock. # Check that we're holding the lock.
if ( if (
lock_id is not None lock_id is not None
and await self.redis.get(prefix_redis_token_bytes(self._lock_key(token))) and await self.redis.get(self._lock_key(token)) != lock_id
!= lock_id
): ):
raise LockExpiredError( raise LockExpiredError(
f"Lock expired for token {token} while processing. Consider increasing " f"Lock expired for token {token} while processing. Consider increasing "
@ -3434,7 +3417,7 @@ class StateManagerRedis(StateManager):
""" """
# All substates share the same lock domain, so ignore any substate path suffix. # All substates share the same lock domain, so ignore any substate path suffix.
client_token = _split_substate_key(token)[0] client_token = _split_substate_key(token)[0]
return f"{client_token}_lock".encode() return prefix_redis_token_str(f"{client_token}_lock").encode()
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
"""Try to get a redis lock for a token. """Try to get a redis lock for a token.
@ -3447,7 +3430,7 @@ class StateManagerRedis(StateManager):
True if the lock was obtained. True if the lock was obtained.
""" """
return await self.redis.set( return await self.redis.set(
prefix_redis_token_bytes(lock_key), lock_key,
lock_id, lock_id,
px=self.lock_expiration, px=self.lock_expiration,
nx=True, # only set if it doesn't exist nx=True, # only set if it doesn't exist
@ -3482,7 +3465,7 @@ class StateManagerRedis(StateManager):
while not state_is_locked: while not state_is_locked:
# wait for the lock to be released # wait for the lock to be released
while True: while True:
if not await self.redis.exists(prefix_redis_token_bytes(lock_key)): if not await self.redis.exists(lock_key):
break # key was removed, try to get the lock again break # key was removed, try to get the lock again
message = await pubsub.get_message( message = await pubsub.get_message(
ignore_subscribe_messages=True, ignore_subscribe_messages=True,
@ -3523,7 +3506,7 @@ class StateManagerRedis(StateManager):
finally: finally:
if state_is_locked: if state_is_locked:
# only delete our lock # only delete our lock
await self.redis.delete(prefix_redis_token_str(lock_key)) await self.redis.delete(lock_key)
async def close(self): async def close(self):
"""Explicitly close the redis connection and connection_pool. """Explicitly close the redis connection and connection_pool.