/health endpoint for K8 Liveness and Readiness probes (#3855)

* Added API Endpoint

* Added API Endpoint

* Added Unit Tests

* Added Unit Tests

* main

* Apply suggestions from Code Review

* Fix Ruff Formatting

* Update Socket Events

* Async Functions
This commit is contained in:
Samarth Bhadane 2024-09-03 18:34:03 -07:00 committed by GitHub
parent 15a9f0a104
commit 59047303c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 189 additions and 2 deletions

View File

@ -33,7 +33,7 @@ from typing import (
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer
@ -65,7 +65,7 @@ from reflex.components.core.upload import Upload, get_upload_dir
from reflex.components.radix import themes
from reflex.config import get_config
from reflex.event import Event, EventHandler, EventSpec, window_alert
from reflex.model import Model
from reflex.model import Model, get_db_status
from reflex.page import (
DECORATED_PAGES,
)
@ -377,6 +377,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
"""Add default api endpoints (ping)."""
# To test the server.
self.api.get(str(constants.Endpoint.PING))(ping)
self.api.get(str(constants.Endpoint.HEALTH))(health)
def _add_optional_endpoints(self):
"""Add optional api endpoints (_upload)."""
@ -1319,6 +1320,38 @@ async def ping() -> str:
return "pong"
async def health() -> JSONResponse:
"""Health check endpoint to assess the status of the database and Redis services.
Returns:
JSONResponse: A JSON object with the health status:
- "status" (bool): Overall health, True if all checks pass.
- "db" (bool or str): Database status - True, False, or "NA".
- "redis" (bool or str): Redis status - True, False, or "NA".
"""
health_status = {"status": True}
status_code = 200
db_status, redis_status = await asyncio.gather(
get_db_status(), prerequisites.get_redis_status()
)
health_status["db"] = db_status
if redis_status is None:
health_status["redis"] = False
else:
health_status["redis"] = redis_status
if not health_status["db"] or (
not health_status["redis"] and redis_status is not None
):
health_status["status"] = False
status_code = 503
return JSONResponse(content=health_status, status_code=status_code)
def upload(app: App):
"""Upload a file.

View File

@ -11,6 +11,7 @@ class Endpoint(Enum):
EVENT = "_event"
UPLOAD = "_upload"
AUTH_CODESPACE = "auth-codespace"
HEALTH = "_health"
def __str__(self) -> str:
"""Get the string representation of the endpoint.

View File

@ -15,6 +15,7 @@ import alembic.runtime.environment
import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.orm
from reflex import constants
@ -51,6 +52,27 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
async def get_db_status() -> bool:
"""Checks the status of the database connection.
Attempts to connect to the database and execute a simple query to verify connectivity.
Returns:
bool: The status of the database connection:
- True: The database is accessible.
- False: The database is not accessible.
"""
status = True
try:
engine = get_engine()
with engine.connect() as connection:
connection.execute(sqlalchemy.text("SELECT 1"))
except sqlalchemy.exc.OperationalError:
status = False
return status
SQLModelOrSqlAlchemy = Union[
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
]

View File

@ -28,6 +28,7 @@ import typer
from alembic.util.exc import CommandError
from packaging import version
from redis import Redis as RedisSync
from redis import exceptions
from redis.asyncio import Redis
from reflex import constants, model
@ -344,6 +345,30 @@ def parse_redis_url() -> str | dict | None:
return dict(host=redis_url, port=int(redis_port), db=0)
async def get_redis_status() -> bool | None:
"""Checks the status of the Redis connection.
Attempts to connect to Redis and send a ping command to verify connectivity.
Returns:
bool or None: The status of the Redis connection:
- True: Redis is accessible and responding.
- False: Redis is not accessible due to a connection error.
- None: Redis not used i.e redis_url is not set in rxconfig.
"""
try:
status = True
redis_client = get_redis_sync()
if redis_client is not None:
redis_client.ping()
else:
status = None
except exceptions.RedisError:
status = False
return status
def validate_app_name(app_name: str | None = None) -> str:
"""Validate the app name.

View File

@ -0,0 +1,106 @@
import json
from unittest.mock import MagicMock, Mock
import pytest
import sqlalchemy
from redis.exceptions import RedisError
from reflex.app import health
from reflex.model import get_db_status
from reflex.utils.prerequisites import get_redis_status
@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_redis_client, expected_status",
[
# Case 1: Redis client is available and responds to ping
(Mock(ping=lambda: None), True),
# Case 2: Redis client raises RedisError
(Mock(ping=lambda: (_ for _ in ()).throw(RedisError)), False),
# Case 3: Redis client is not used
(None, None),
],
)
async def test_get_redis_status(mock_redis_client, expected_status, mocker):
# Mock the `get_redis_sync` function to return the mock Redis client
mock_get_redis_sync = mocker.patch(
"reflex.utils.prerequisites.get_redis_sync", return_value=mock_redis_client
)
# Call the function
status = await get_redis_status()
# Verify the result
assert status == expected_status
mock_get_redis_sync.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_engine, execute_side_effect, expected_status",
[
# Case 1: Database is accessible
(MagicMock(), None, True),
# Case 2: Database connection error (OperationalError)
(
MagicMock(),
sqlalchemy.exc.OperationalError("error", "error", "error"),
False,
),
],
)
async def test_get_db_status(mock_engine, execute_side_effect, expected_status, mocker):
# Mock get_engine to return the mock_engine
mock_get_engine = mocker.patch("reflex.model.get_engine", return_value=mock_engine)
# Mock the connection and its execute method
if mock_engine:
mock_connection = mock_engine.connect.return_value.__enter__.return_value
if execute_side_effect:
# Simulate execute method raising an exception
mock_connection.execute.side_effect = execute_side_effect
else:
# Simulate successful execute call
mock_connection.execute.return_value = None
# Call the function
status = await get_db_status()
# Verify the result
assert status == expected_status
mock_get_engine.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"db_status, redis_status, expected_status, expected_code",
[
# Case 1: Both services are connected
(True, True, {"status": True, "db": True, "redis": True}, 200),
# Case 2: Database not connected, Redis connected
(False, True, {"status": False, "db": False, "redis": True}, 503),
# Case 3: Database connected, Redis not connected
(True, False, {"status": False, "db": True, "redis": False}, 503),
# Case 4: Both services not connected
(False, False, {"status": False, "db": False, "redis": False}, 503),
# Case 5: Database Connected, Redis not used
(True, None, {"status": True, "db": True, "redis": False}, 200),
],
)
async def test_health(db_status, redis_status, expected_status, expected_code, mocker):
# Mock get_db_status and get_redis_status
mocker.patch("reflex.app.get_db_status", return_value=db_status)
mocker.patch(
"reflex.utils.prerequisites.get_redis_status", return_value=redis_status
)
# Call the async health function
response = await health()
print(json.loads(response.body))
print(expected_status)
# Verify the response content and status code
assert response.status_code == expected_code
assert json.loads(response.body) == expected_status