improve performance

This commit is contained in:
Benedikt Bartscher 2024-11-21 02:51:10 +01:00
parent 4b2cec3784
commit b28f4ac1c6
No known key found for this signature in database
2 changed files with 30 additions and 14 deletions

View File

@ -3142,7 +3142,7 @@ TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)
@functools.lru_cache @functools.lru_cache
def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE: def prefix_redis_token_str(token: str) -> str:
"""Prefix the token with the redis prefix. """Prefix the token with the redis prefix.
Args: Args:
@ -3154,11 +3154,25 @@ def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE:
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get() prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
if not prefix: if not prefix:
return token return token
if isinstance(token, bytes):
return prefix.encode() + token
return f"{prefix}{token}" return f"{prefix}{token}"
@functools.lru_cache
def prefix_redis_token_bytes(token: bytes) -> bytes:
"""Prefix the token with the redis prefix.
Args:
token: The token to prefix.
Returns:
The prefixed token.
"""
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
if not prefix:
return token
return prefix.encode() + token
class StateManagerRedis(StateManager): class StateManagerRedis(StateManager):
"""A state manager that stores states in redis.""" """A state manager that stores states in redis."""
@ -3296,7 +3310,7 @@ class StateManagerRedis(StateManager):
state = None state = None
# Fetch the serialized substate from redis. # Fetch the serialized substate from redis.
redis_state = await self.redis.get(prefix_redis_token(token)) redis_state = await self.redis.get(prefix_redis_token_str(token))
if redis_state is not None: if redis_state is not None:
# Deserialize the substate. # Deserialize the substate.
@ -3351,7 +3365,7 @@ class StateManagerRedis(StateManager):
# Check that we're holding the lock. # Check that we're holding the lock.
if ( if (
lock_id is not None lock_id is not None
and await self.redis.get(prefix_redis_token(self._lock_key(token))) and await self.redis.get(prefix_redis_token_str(self._lock_key(token)))
!= lock_id != lock_id
): ):
raise LockExpiredError( raise LockExpiredError(
@ -3383,7 +3397,7 @@ class StateManagerRedis(StateManager):
pickle_state = state._serialize() pickle_state = state._serialize()
if pickle_state: if pickle_state:
await self.redis.set( await self.redis.set(
prefix_redis_token(_substate_key(client_token, state)), prefix_redis_token_str(_substate_key(client_token, state)),
pickle_state, pickle_state,
ex=self.token_expiration, ex=self.token_expiration,
) )
@ -3433,7 +3447,7 @@ class StateManagerRedis(StateManager):
True if the lock was obtained. True if the lock was obtained.
""" """
return await self.redis.set( return await self.redis.set(
prefix_redis_token(lock_key), prefix_redis_token_bytes(lock_key),
lock_id, lock_id,
px=self.lock_expiration, px=self.lock_expiration,
nx=True, # only set if it doesn't exist nx=True, # only set if it doesn't exist
@ -3468,7 +3482,7 @@ class StateManagerRedis(StateManager):
while not state_is_locked: while not state_is_locked:
# wait for the lock to be released # wait for the lock to be released
while True: while True:
if not await self.redis.exists(prefix_redis_token(lock_key)): if not await self.redis.exists(prefix_redis_token_bytes(lock_key)):
break # key was removed, try to get the lock again break # key was removed, try to get the lock again
message = await pubsub.get_message( message = await pubsub.get_message(
ignore_subscribe_messages=True, ignore_subscribe_messages=True,
@ -3509,7 +3523,7 @@ class StateManagerRedis(StateManager):
finally: finally:
if state_is_locked: if state_is_locked:
# only delete our lock # only delete our lock
await self.redis.delete(prefix_redis_token(lock_key)) await self.redis.delete(prefix_redis_token_str(lock_key))
async def close(self): async def close(self):
"""Explicitly close the redis connection and connection_pool. """Explicitly close the redis connection and connection_pool.

View File

@ -40,7 +40,7 @@ from reflex.state import (
StateProxy, StateProxy,
StateUpdate, StateUpdate,
_substate_key, _substate_key,
prefix_redis_token, prefix_redis_token_str,
) )
from reflex.testing import chdir from reflex.testing import chdir
from reflex.utils import format, prerequisites, types from reflex.utils import format, prerequisites, types
@ -1672,7 +1672,9 @@ async def test_state_manager_modify_state(
""" """
async with state_manager.modify_state(substate_token) as state: async with state_manager.modify_state(substate_token) as state:
if isinstance(state_manager, StateManagerRedis): if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) assert await state_manager.redis.get(
prefix_redis_token_str(f"{token}_lock")
)
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked() assert state_manager._states_locks[token].locked()
@ -1683,7 +1685,7 @@ async def test_state_manager_modify_state(
# lock should be dropped after exiting the context # lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis): if isinstance(state_manager, StateManagerRedis):
assert ( assert (
await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock"))
) is None ) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert not state_manager._states_locks[token].locked() assert not state_manager._states_locks[token].locked()
@ -1727,7 +1729,7 @@ async def test_state_manager_contend(
if isinstance(state_manager, StateManagerRedis): if isinstance(state_manager, StateManagerRedis):
assert ( assert (
await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock"))
) is None ) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks assert token in state_manager._states_locks
@ -1859,7 +1861,7 @@ async def test_state_manager_redis_prefix(
async with state_manager_redis.modify_state(substate_token_redis) as state: async with state_manager_redis.modify_state(substate_token_redis) as state:
state.num1 = 42 state.num1 = 42
prefixed_token = prefix_redis_token(substate_token_redis) prefixed_token = prefix_redis_token_str(substate_token_redis)
assert prefixed_token == f"{redis_prefix}{substate_token_redis}" assert prefixed_token == f"{redis_prefix}{substate_token_redis}"
assert await state_manager_redis.redis.get(prefixed_token) assert await state_manager_redis.redis.get(prefixed_token)