diff --git a/reflex/app.py b/reflex/app.py index 5be0ef040..a3094885d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -33,7 +33,7 @@ from typing import ( from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer @@ -65,7 +65,7 @@ from reflex.components.core.upload import Upload, get_upload_dir from reflex.components.radix import themes from reflex.config import get_config from reflex.event import Event, EventHandler, EventSpec, window_alert -from reflex.model import Model +from reflex.model import Model, get_db_status from reflex.page import ( DECORATED_PAGES, ) @@ -377,6 +377,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): """Add default api endpoints (ping).""" # To test the server. self.api.get(str(constants.Endpoint.PING))(ping) + self.api.get(str(constants.Endpoint.HEALTH))(health) def _add_optional_endpoints(self): """Add optional api endpoints (_upload).""" @@ -1319,6 +1320,38 @@ async def ping() -> str: return "pong" +async def health() -> JSONResponse: + """Health check endpoint to assess the status of the database and Redis services. + + Returns: + JSONResponse: A JSON object with the health status: + - "status" (bool): Overall health, True if all checks pass. + - "db" (bool or str): Database status - True, False, or "NA". + - "redis" (bool or str): Redis status - True, False, or "NA". + """ + health_status = {"status": True} + status_code = 200 + + db_status, redis_status = await asyncio.gather( + get_db_status(), prerequisites.get_redis_status() + ) + + health_status["db"] = db_status + + if redis_status is None: + health_status["redis"] = False + else: + health_status["redis"] = redis_status + + if not health_status["db"] or ( + not health_status["redis"] and redis_status is not None + ): + health_status["status"] = False + status_code = 503 + + return JSONResponse(content=health_status, status_code=status_code) + + def upload(app: App): """Upload a file. diff --git a/reflex/constants/event.py b/reflex/constants/event.py index 351a1ac52..d454e6ea8 100644 --- a/reflex/constants/event.py +++ b/reflex/constants/event.py @@ -11,6 +11,7 @@ class Endpoint(Enum): EVENT = "_event" UPLOAD = "_upload" AUTH_CODESPACE = "auth-codespace" + HEALTH = "_health" def __str__(self) -> str: """Get the string representation of the endpoint. diff --git a/reflex/model.py b/reflex/model.py index 71e26f76a..fefb1f9e9 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -15,6 +15,7 @@ import alembic.runtime.environment import alembic.script import alembic.util import sqlalchemy +import sqlalchemy.exc import sqlalchemy.orm from reflex import constants @@ -51,6 +52,27 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args) +async def get_db_status() -> bool: + """Checks the status of the database connection. + + Attempts to connect to the database and execute a simple query to verify connectivity. + + Returns: + bool: The status of the database connection: + - True: The database is accessible. + - False: The database is not accessible. + """ + status = True + try: + engine = get_engine() + with engine.connect() as connection: + connection.execute(sqlalchemy.text("SELECT 1")) + except sqlalchemy.exc.OperationalError: + status = False + + return status + + SQLModelOrSqlAlchemy = Union[ Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase] ] diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 697d51cf2..902c5111c 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -28,6 +28,7 @@ import typer from alembic.util.exc import CommandError from packaging import version from redis import Redis as RedisSync +from redis import exceptions from redis.asyncio import Redis from reflex import constants, model @@ -344,6 +345,30 @@ def parse_redis_url() -> str | dict | None: return dict(host=redis_url, port=int(redis_port), db=0) +async def get_redis_status() -> bool | None: + """Checks the status of the Redis connection. + + Attempts to connect to Redis and send a ping command to verify connectivity. + + Returns: + bool or None: The status of the Redis connection: + - True: Redis is accessible and responding. + - False: Redis is not accessible due to a connection error. + - None: Redis not used i.e redis_url is not set in rxconfig. + """ + try: + status = True + redis_client = get_redis_sync() + if redis_client is not None: + redis_client.ping() + else: + status = None + except exceptions.RedisError: + status = False + + return status + + def validate_app_name(app_name: str | None = None) -> str: """Validate the app name. diff --git a/tests/test_health_endpoint.py b/tests/test_health_endpoint.py new file mode 100644 index 000000000..fe350266f --- /dev/null +++ b/tests/test_health_endpoint.py @@ -0,0 +1,106 @@ +import json +from unittest.mock import MagicMock, Mock + +import pytest +import sqlalchemy +from redis.exceptions import RedisError + +from reflex.app import health +from reflex.model import get_db_status +from reflex.utils.prerequisites import get_redis_status + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_redis_client, expected_status", + [ + # Case 1: Redis client is available and responds to ping + (Mock(ping=lambda: None), True), + # Case 2: Redis client raises RedisError + (Mock(ping=lambda: (_ for _ in ()).throw(RedisError)), False), + # Case 3: Redis client is not used + (None, None), + ], +) +async def test_get_redis_status(mock_redis_client, expected_status, mocker): + # Mock the `get_redis_sync` function to return the mock Redis client + mock_get_redis_sync = mocker.patch( + "reflex.utils.prerequisites.get_redis_sync", return_value=mock_redis_client + ) + + # Call the function + status = await get_redis_status() + + # Verify the result + assert status == expected_status + mock_get_redis_sync.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_engine, execute_side_effect, expected_status", + [ + # Case 1: Database is accessible + (MagicMock(), None, True), + # Case 2: Database connection error (OperationalError) + ( + MagicMock(), + sqlalchemy.exc.OperationalError("error", "error", "error"), + False, + ), + ], +) +async def test_get_db_status(mock_engine, execute_side_effect, expected_status, mocker): + # Mock get_engine to return the mock_engine + mock_get_engine = mocker.patch("reflex.model.get_engine", return_value=mock_engine) + + # Mock the connection and its execute method + if mock_engine: + mock_connection = mock_engine.connect.return_value.__enter__.return_value + if execute_side_effect: + # Simulate execute method raising an exception + mock_connection.execute.side_effect = execute_side_effect + else: + # Simulate successful execute call + mock_connection.execute.return_value = None + + # Call the function + status = await get_db_status() + + # Verify the result + assert status == expected_status + mock_get_engine.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "db_status, redis_status, expected_status, expected_code", + [ + # Case 1: Both services are connected + (True, True, {"status": True, "db": True, "redis": True}, 200), + # Case 2: Database not connected, Redis connected + (False, True, {"status": False, "db": False, "redis": True}, 503), + # Case 3: Database connected, Redis not connected + (True, False, {"status": False, "db": True, "redis": False}, 503), + # Case 4: Both services not connected + (False, False, {"status": False, "db": False, "redis": False}, 503), + # Case 5: Database Connected, Redis not used + (True, None, {"status": True, "db": True, "redis": False}, 200), + ], +) +async def test_health(db_status, redis_status, expected_status, expected_code, mocker): + # Mock get_db_status and get_redis_status + mocker.patch("reflex.app.get_db_status", return_value=db_status) + mocker.patch( + "reflex.utils.prerequisites.get_redis_status", return_value=redis_status + ) + + # Call the async health function + response = await health() + + print(json.loads(response.body)) + print(expected_status) + + # Verify the response content and status code + assert response.status_code == expected_code + assert json.loads(response.body) == expected_status