From a3be76fb75a71b9979d497b89f3e9e1bef02d30a Mon Sep 17 00:00:00 2001 From: Martin Xu <15661672+martinxu9@users.noreply.github.com> Date: Wed, 21 Feb 2024 07:01:44 -0800 Subject: [PATCH] use sync redis client to sanity check (#2679) --- reflex/state.py | 1 + reflex/utils/prerequisites.py | 35 +++++++++++++++++++++++++++++++---- reflex/utils/processes.py | 13 ++++++++++++- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 50468c0b4..7b377b4d0 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1,4 +1,5 @@ """Define the reflex state specification.""" + from __future__ import annotations import asyncio diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 8fe5493ac..5e12a60d0 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -24,6 +24,7 @@ import pkg_resources import typer from alembic.util.exc import CommandError from packaging import version +from redis import Redis as RedisSync from redis.asyncio import Redis import reflex @@ -189,16 +190,42 @@ def get_compiled_app(reload: bool = False) -> ModuleType: def get_redis() -> Redis | None: - """Get the redis client. + """Get the asynchronous redis client. Returns: - The redis client. + The asynchronous redis client. + """ + if isinstance((redis_url_or_options := parse_redis_url()), str): + return Redis.from_url(redis_url_or_options) + elif isinstance(redis_url_or_options, dict): + return Redis(**redis_url_or_options) + return None + + +def get_redis_sync() -> RedisSync | None: + """Get the synchronous redis client. + + Returns: + The synchronous redis client. + """ + if isinstance((redis_url_or_options := parse_redis_url()), str): + return RedisSync.from_url(redis_url_or_options) + elif isinstance(redis_url_or_options, dict): + return RedisSync(**redis_url_or_options) + return None + + +def parse_redis_url() -> str | dict | None: + """Parse the REDIS_URL in config if applicable. + + Returns: + If redis-py syntax, return the URL as it is. Otherwise, return the host/port/db as a dict. """ config = get_config() if not config.redis_url: return None if config.redis_url.startswith(("redis://", "rediss://", "unix://")): - return Redis.from_url(config.redis_url) + return config.redis_url console.deprecate( feature_name="host[:port] style redis urls", reason="redis-py url syntax is now being used", @@ -209,7 +236,7 @@ def get_redis() -> Redis | None: if not has_port: redis_port = 6379 console.info(f"Using redis at {config.redis_url}") - return Redis(host=redis_url, port=int(redis_port), db=0) + return dict(host=redis_url, port=int(redis_port), db=0) def get_production_backend_url() -> str: diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index 19a2d6804..ce165d70c 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -12,6 +12,7 @@ from typing import Callable, Generator, List, Optional, Tuple, Union import psutil import typer +from redis.exceptions import RedisError from reflex.utils import console, path_ops, prerequisites @@ -28,10 +29,20 @@ def kill(pid): def get_num_workers() -> int: """Get the number of backend worker processes. + Raises: + Exit: If unable to connect to Redis. + Returns: The number of backend worker processes. """ - return 1 if prerequisites.get_redis() is None else (os.cpu_count() or 1) * 2 + 1 + if (redis_client := prerequisites.get_redis_sync()) is None: + return 1 + try: + redis_client.ping() + except RedisError as re: + console.error(f"Unable to connect to Redis: {re}") + raise typer.Exit(1) from re + return (os.cpu_count() or 1) * 2 + 1 def get_process_on_port(port) -> Optional[psutil.Process]: