sanity check that config db_url and async_db_url are the same

throw a warning if the part following the `://` differs between these two
This commit is contained in:
Masen Furer 2024-12-06 12:27:36 -08:00
parent e935115c74
commit 764d5768fa
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 29 additions and 3 deletions

View File

@ -625,7 +625,7 @@ class Config(Base):
db_url: Optional[str] = "sqlite:///reflex.db"
# The async database url used by rx.Model.
async_db_url: Optional[str] = "sqlite+aiosqlite:///reflex.db"
async_db_url: Optional[str] = None
# The redis url
redis_url: Optional[str] = None

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import re
from collections import defaultdict
from typing import Any, ClassVar, Optional, Type, Union
@ -26,6 +27,21 @@ _ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {}
# Import asyncio session _after_ reflex.utils.compat
import sqlmodel.ext.asyncio.session # noqa: E402
def _safe_db_url_for_logging(url: str) -> str:
"""Remove username and password from the database URL for logging.
Args:
url: The database URL.
Returns:
The database URL with the username and password removed.
"""
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
def get_engine_args(url: str | None = None) -> dict[str, Any]:
"""Get the database engine arguments.
@ -94,8 +110,18 @@ def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
Raises:
ValueError: If the async database url is None.
"""
conf = get_config()
url = url or conf.async_db_url
if url is None:
conf = get_config()
url = conf.async_db_url
if url is not None and conf.db_url is not None:
async_db_url_tail = url.partition("://")[2]
db_url_tail = conf.db_url.partition("://")[2]
if async_db_url_tail != db_url_tail:
console.warn(
f"async_db_url `{_safe_db_url_for_logging(url)}` "
"should reference the same database as "
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
)
if url is None:
raise ValueError("No async database url configured")