/health endpoint for K8 Liveness and Readiness probes (#3855)
* Added API Endpoint * Added API Endpoint * Added Unit Tests * Added Unit Tests * main * Apply suggestions from Code Review * Fix Ruff Formatting * Update Socket Events * Async Functions
This commit is contained in:
parent
15a9f0a104
commit
59047303c9
@ -33,7 +33,7 @@ from typing import (
|
|||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request, UploadFile
|
from fastapi import FastAPI, HTTPException, Request, UploadFile
|
||||||
from fastapi.middleware import cors
|
from fastapi.middleware import cors
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
|
||||||
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
||||||
@ -65,7 +65,7 @@ from reflex.components.core.upload import Upload, get_upload_dir
|
|||||||
from reflex.components.radix import themes
|
from reflex.components.radix import themes
|
||||||
from reflex.config import get_config
|
from reflex.config import get_config
|
||||||
from reflex.event import Event, EventHandler, EventSpec, window_alert
|
from reflex.event import Event, EventHandler, EventSpec, window_alert
|
||||||
from reflex.model import Model
|
from reflex.model import Model, get_db_status
|
||||||
from reflex.page import (
|
from reflex.page import (
|
||||||
DECORATED_PAGES,
|
DECORATED_PAGES,
|
||||||
)
|
)
|
||||||
@ -377,6 +377,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
|
|||||||
"""Add default api endpoints (ping)."""
|
"""Add default api endpoints (ping)."""
|
||||||
# To test the server.
|
# To test the server.
|
||||||
self.api.get(str(constants.Endpoint.PING))(ping)
|
self.api.get(str(constants.Endpoint.PING))(ping)
|
||||||
|
self.api.get(str(constants.Endpoint.HEALTH))(health)
|
||||||
|
|
||||||
def _add_optional_endpoints(self):
|
def _add_optional_endpoints(self):
|
||||||
"""Add optional api endpoints (_upload)."""
|
"""Add optional api endpoints (_upload)."""
|
||||||
@ -1319,6 +1320,38 @@ async def ping() -> str:
|
|||||||
return "pong"
|
return "pong"
|
||||||
|
|
||||||
|
|
||||||
|
async def health() -> JSONResponse:
|
||||||
|
"""Health check endpoint to assess the status of the database and Redis services.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: A JSON object with the health status:
|
||||||
|
- "status" (bool): Overall health, True if all checks pass.
|
||||||
|
- "db" (bool or str): Database status - True, False, or "NA".
|
||||||
|
- "redis" (bool or str): Redis status - True, False, or "NA".
|
||||||
|
"""
|
||||||
|
health_status = {"status": True}
|
||||||
|
status_code = 200
|
||||||
|
|
||||||
|
db_status, redis_status = await asyncio.gather(
|
||||||
|
get_db_status(), prerequisites.get_redis_status()
|
||||||
|
)
|
||||||
|
|
||||||
|
health_status["db"] = db_status
|
||||||
|
|
||||||
|
if redis_status is None:
|
||||||
|
health_status["redis"] = False
|
||||||
|
else:
|
||||||
|
health_status["redis"] = redis_status
|
||||||
|
|
||||||
|
if not health_status["db"] or (
|
||||||
|
not health_status["redis"] and redis_status is not None
|
||||||
|
):
|
||||||
|
health_status["status"] = False
|
||||||
|
status_code = 503
|
||||||
|
|
||||||
|
return JSONResponse(content=health_status, status_code=status_code)
|
||||||
|
|
||||||
|
|
||||||
def upload(app: App):
|
def upload(app: App):
|
||||||
"""Upload a file.
|
"""Upload a file.
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ class Endpoint(Enum):
|
|||||||
EVENT = "_event"
|
EVENT = "_event"
|
||||||
UPLOAD = "_upload"
|
UPLOAD = "_upload"
|
||||||
AUTH_CODESPACE = "auth-codespace"
|
AUTH_CODESPACE = "auth-codespace"
|
||||||
|
HEALTH = "_health"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Get the string representation of the endpoint.
|
"""Get the string representation of the endpoint.
|
||||||
|
@ -15,6 +15,7 @@ import alembic.runtime.environment
|
|||||||
import alembic.script
|
import alembic.script
|
||||||
import alembic.util
|
import alembic.util
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
import sqlalchemy.exc
|
||||||
import sqlalchemy.orm
|
import sqlalchemy.orm
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
@ -51,6 +52,27 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
|||||||
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_status() -> bool:
|
||||||
|
"""Checks the status of the database connection.
|
||||||
|
|
||||||
|
Attempts to connect to the database and execute a simple query to verify connectivity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: The status of the database connection:
|
||||||
|
- True: The database is accessible.
|
||||||
|
- False: The database is not accessible.
|
||||||
|
"""
|
||||||
|
status = True
|
||||||
|
try:
|
||||||
|
engine = get_engine()
|
||||||
|
with engine.connect() as connection:
|
||||||
|
connection.execute(sqlalchemy.text("SELECT 1"))
|
||||||
|
except sqlalchemy.exc.OperationalError:
|
||||||
|
status = False
|
||||||
|
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
SQLModelOrSqlAlchemy = Union[
|
SQLModelOrSqlAlchemy = Union[
|
||||||
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
|
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
|
||||||
]
|
]
|
||||||
|
@ -28,6 +28,7 @@ import typer
|
|||||||
from alembic.util.exc import CommandError
|
from alembic.util.exc import CommandError
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from redis import Redis as RedisSync
|
from redis import Redis as RedisSync
|
||||||
|
from redis import exceptions
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
from reflex import constants, model
|
from reflex import constants, model
|
||||||
@ -344,6 +345,30 @@ def parse_redis_url() -> str | dict | None:
|
|||||||
return dict(host=redis_url, port=int(redis_port), db=0)
|
return dict(host=redis_url, port=int(redis_port), db=0)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redis_status() -> bool | None:
|
||||||
|
"""Checks the status of the Redis connection.
|
||||||
|
|
||||||
|
Attempts to connect to Redis and send a ping command to verify connectivity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool or None: The status of the Redis connection:
|
||||||
|
- True: Redis is accessible and responding.
|
||||||
|
- False: Redis is not accessible due to a connection error.
|
||||||
|
- None: Redis not used i.e redis_url is not set in rxconfig.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
status = True
|
||||||
|
redis_client = get_redis_sync()
|
||||||
|
if redis_client is not None:
|
||||||
|
redis_client.ping()
|
||||||
|
else:
|
||||||
|
status = None
|
||||||
|
except exceptions.RedisError:
|
||||||
|
status = False
|
||||||
|
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
def validate_app_name(app_name: str | None = None) -> str:
|
def validate_app_name(app_name: str | None = None) -> str:
|
||||||
"""Validate the app name.
|
"""Validate the app name.
|
||||||
|
|
||||||
|
106
tests/test_health_endpoint.py
Normal file
106
tests/test_health_endpoint.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
from reflex.app import health
|
||||||
|
from reflex.model import get_db_status
|
||||||
|
from reflex.utils.prerequisites import get_redis_status
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mock_redis_client, expected_status",
|
||||||
|
[
|
||||||
|
# Case 1: Redis client is available and responds to ping
|
||||||
|
(Mock(ping=lambda: None), True),
|
||||||
|
# Case 2: Redis client raises RedisError
|
||||||
|
(Mock(ping=lambda: (_ for _ in ()).throw(RedisError)), False),
|
||||||
|
# Case 3: Redis client is not used
|
||||||
|
(None, None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_get_redis_status(mock_redis_client, expected_status, mocker):
|
||||||
|
# Mock the `get_redis_sync` function to return the mock Redis client
|
||||||
|
mock_get_redis_sync = mocker.patch(
|
||||||
|
"reflex.utils.prerequisites.get_redis_sync", return_value=mock_redis_client
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
status = await get_redis_status()
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert status == expected_status
|
||||||
|
mock_get_redis_sync.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mock_engine, execute_side_effect, expected_status",
|
||||||
|
[
|
||||||
|
# Case 1: Database is accessible
|
||||||
|
(MagicMock(), None, True),
|
||||||
|
# Case 2: Database connection error (OperationalError)
|
||||||
|
(
|
||||||
|
MagicMock(),
|
||||||
|
sqlalchemy.exc.OperationalError("error", "error", "error"),
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_get_db_status(mock_engine, execute_side_effect, expected_status, mocker):
|
||||||
|
# Mock get_engine to return the mock_engine
|
||||||
|
mock_get_engine = mocker.patch("reflex.model.get_engine", return_value=mock_engine)
|
||||||
|
|
||||||
|
# Mock the connection and its execute method
|
||||||
|
if mock_engine:
|
||||||
|
mock_connection = mock_engine.connect.return_value.__enter__.return_value
|
||||||
|
if execute_side_effect:
|
||||||
|
# Simulate execute method raising an exception
|
||||||
|
mock_connection.execute.side_effect = execute_side_effect
|
||||||
|
else:
|
||||||
|
# Simulate successful execute call
|
||||||
|
mock_connection.execute.return_value = None
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
status = await get_db_status()
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert status == expected_status
|
||||||
|
mock_get_engine.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"db_status, redis_status, expected_status, expected_code",
|
||||||
|
[
|
||||||
|
# Case 1: Both services are connected
|
||||||
|
(True, True, {"status": True, "db": True, "redis": True}, 200),
|
||||||
|
# Case 2: Database not connected, Redis connected
|
||||||
|
(False, True, {"status": False, "db": False, "redis": True}, 503),
|
||||||
|
# Case 3: Database connected, Redis not connected
|
||||||
|
(True, False, {"status": False, "db": True, "redis": False}, 503),
|
||||||
|
# Case 4: Both services not connected
|
||||||
|
(False, False, {"status": False, "db": False, "redis": False}, 503),
|
||||||
|
# Case 5: Database Connected, Redis not used
|
||||||
|
(True, None, {"status": True, "db": True, "redis": False}, 200),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_health(db_status, redis_status, expected_status, expected_code, mocker):
|
||||||
|
# Mock get_db_status and get_redis_status
|
||||||
|
mocker.patch("reflex.app.get_db_status", return_value=db_status)
|
||||||
|
mocker.patch(
|
||||||
|
"reflex.utils.prerequisites.get_redis_status", return_value=redis_status
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the async health function
|
||||||
|
response = await health()
|
||||||
|
|
||||||
|
print(json.loads(response.body))
|
||||||
|
print(expected_status)
|
||||||
|
|
||||||
|
# Verify the response content and status code
|
||||||
|
assert response.status_code == expected_code
|
||||||
|
assert json.loads(response.body) == expected_status
|
Loading…
Reference in New Issue
Block a user