diff --git a/reflex/state.py b/reflex/state.py index 042c99732..41fa3cd40 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3142,7 +3142,7 @@ TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes) @functools.lru_cache -def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE: +def prefix_redis_token_str(token: str) -> str: """Prefix the token with the redis prefix. Args: @@ -3154,11 +3154,25 @@ def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE: prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get() if not prefix: return token - if isinstance(token, bytes): - return prefix.encode() + token return f"{prefix}{token}" +@functools.lru_cache +def prefix_redis_token_bytes(token: bytes) -> bytes: + """Prefix the token with the redis prefix. + + Args: + token: The token to prefix. + + Returns: + The prefixed token. + """ + prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get() + if not prefix: + return token + return prefix.encode() + token + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3296,7 +3310,7 @@ class StateManagerRedis(StateManager): state = None # Fetch the serialized substate from redis. - redis_state = await self.redis.get(prefix_redis_token(token)) + redis_state = await self.redis.get(prefix_redis_token_str(token)) if redis_state is not None: # Deserialize the substate. @@ -3351,7 +3365,7 @@ class StateManagerRedis(StateManager): # Check that we're holding the lock. if ( lock_id is not None - and await self.redis.get(prefix_redis_token(self._lock_key(token))) + and await self.redis.get(prefix_redis_token_str(self._lock_key(token))) != lock_id ): raise LockExpiredError( @@ -3383,7 +3397,7 @@ class StateManagerRedis(StateManager): pickle_state = state._serialize() if pickle_state: await self.redis.set( - prefix_redis_token(_substate_key(client_token, state)), + prefix_redis_token_str(_substate_key(client_token, state)), pickle_state, ex=self.token_expiration, ) @@ -3433,7 +3447,7 @@ class StateManagerRedis(StateManager): True if the lock was obtained. """ return await self.redis.set( - prefix_redis_token(lock_key), + prefix_redis_token_bytes(lock_key), lock_id, px=self.lock_expiration, nx=True, # only set if it doesn't exist @@ -3468,7 +3482,7 @@ class StateManagerRedis(StateManager): while not state_is_locked: # wait for the lock to be released while True: - if not await self.redis.exists(prefix_redis_token(lock_key)): + if not await self.redis.exists(prefix_redis_token_bytes(lock_key)): break # key was removed, try to get the lock again message = await pubsub.get_message( ignore_subscribe_messages=True, @@ -3509,7 +3523,7 @@ class StateManagerRedis(StateManager): finally: if state_is_locked: # only delete our lock - await self.redis.delete(prefix_redis_token(lock_key)) + await self.redis.delete(prefix_redis_token_str(lock_key)) async def close(self): """Explicitly close the redis connection and connection_pool. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 31f5493db..e1f96f4ce 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -40,7 +40,7 @@ from reflex.state import ( StateProxy, StateUpdate, _substate_key, - prefix_redis_token, + prefix_redis_token_str, ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types @@ -1672,7 +1672,9 @@ async def test_state_manager_modify_state( """ async with state_manager.modify_state(substate_token) as state: if isinstance(state_manager, StateManagerRedis): - assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + assert await state_manager.redis.get( + prefix_redis_token_str(f"{token}_lock") + ) elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() @@ -1683,7 +1685,7 @@ async def test_state_manager_modify_state( # lock should be dropped after exiting the context if isinstance(state_manager, StateManagerRedis): assert ( - await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock")) ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert not state_manager._states_locks[token].locked() @@ -1727,7 +1729,7 @@ async def test_state_manager_contend( if isinstance(state_manager, StateManagerRedis): assert ( - await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock")) ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks @@ -1859,7 +1861,7 @@ async def test_state_manager_redis_prefix( async with state_manager_redis.modify_state(substate_token_redis) as state: state.num1 = 42 - prefixed_token = prefix_redis_token(substate_token_redis) + prefixed_token = prefix_redis_token_str(substate_token_redis) assert prefixed_token == f"{redis_prefix}{substate_token_redis}" assert await state_manager_redis.redis.get(prefixed_token)