Reuse the sqlalchemy engine once it's created (#4493)

* Reuse the sqlalchemy engine once it's created

* Implement `rx.asession` for async database support

Requires setting `async_db_url` to the same database as `db_url`, except using
an async driver; this will vary by database.

* resolve the url first, so the key into _ENGINE is correct

* Ping db connections before returning them from pool

Move connect engine kwargs to a separate function

* the param is `echo`

* sanity check that config db_url and async_db_url are the same

throw a warning if the part following the `://` differs between these two

* create_async_engine: use sqlalchemy async API

update types

* redact ASYNC_DB_URL similarly to DB_URL when overridden in config

* update rx.asession docstring

* use async_sessionmaker

* Redact sensitive env vars instead of hiding them
This commit is contained in:
Masen Furer 2024-12-10 12:37:09 -08:00 committed by GitHub
parent 3ef7106e0e
commit 4922f7ba05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 141 additions and 12 deletions

View File

@ -331,7 +331,7 @@ _MAPPING: dict = {
"SessionStorage",
],
"middleware": ["middleware", "Middleware"],
"model": ["session", "Model"],
"model": ["asession", "session", "Model"],
"state": [
"var",
"ComponentState",

View File

@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
from .middleware import Middleware as Middleware
from .middleware import middleware as middleware
from .model import Model as Model
from .model import asession as asession
from .model import session as session
from .page import page as page
from .state import ComponentState as ComponentState

View File

@ -512,6 +512,9 @@ class EnvironmentVariables:
# Whether to print the SQL queries if the log level is INFO or lower.
SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
# Whether to check db connections before using them.
SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)
# Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
@ -568,6 +571,10 @@ class EnvironmentVariables:
environment = EnvironmentVariables()
# These vars are not logged because they may contain sensitive information.
_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}
class Config(Base):
"""The config defines runtime settings for the app.
@ -621,6 +628,9 @@ class Config(Base):
# The database url used by rx.Model.
db_url: Optional[str] = "sqlite:///reflex.db"
# The async database url used by rx.Model.
async_db_url: Optional[str] = None
# The redis url
redis_url: Optional[str] = None
@ -748,18 +758,20 @@ class Config(Base):
# If the env var is set, override the config value.
if env_var is not None:
if key.upper() != "DB_URL":
console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}",
dedupe=True,
)
# Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
# Set the value.
updated_values[key] = value
if key.upper() in _sensitive_env_vars:
env_var = "***"
console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}",
dedupe=True,
)
return updated_values
def get_event_namespace(self) -> str:

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import re
from collections import defaultdict
from typing import Any, ClassVar, Optional, Type, Union
@ -14,6 +15,7 @@ import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm
from reflex.base import Base
@ -21,6 +23,48 @@ from reflex.config import environment, get_config
from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
# Import AsyncSession _after_ reflex.utils.compat
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
def _safe_db_url_for_logging(url: str) -> str:
"""Remove username and password from the database URL for logging.
Args:
url: The database URL.
Returns:
The database URL with the username and password removed.
"""
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
def get_engine_args(url: str | None = None) -> dict[str, Any]:
"""Get the database engine arguments.
Args:
url: The database url.
Returns:
The database engine arguments as a dict.
"""
kwargs: dict[str, Any] = dict(
# Print the SQL queries if the log level is INFO or lower.
echo=environment.SQLALCHEMY_ECHO.get(),
# Check connections before returning them.
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
)
conf = get_config()
url = url or conf.db_url
if url is not None and url.startswith("sqlite"):
# Needed for the admin dash on sqlite.
kwargs["connect_args"] = {"check_same_thread": False}
return kwargs
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
"""Get the database engine.
@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
url = url or conf.db_url
if url is None:
raise ValueError("No database url configured")
global _ENGINE
if url in _ENGINE:
return _ENGINE[url]
if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
# Print the SQL queries if the log level is INFO or lower.
echo_db_query = environment.SQLALCHEMY_ECHO.get()
# Needed for the admin dash on sqlite.
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
_ENGINE[url] = sqlmodel.create_engine(
url,
**get_engine_args(url),
)
return _ENGINE[url]
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
"""Get the async database engine.
Args:
url: The database url.
Returns:
The async database engine.
Raises:
ValueError: If the async database url is None.
"""
if url is None:
conf = get_config()
url = conf.async_db_url
if url is not None and conf.db_url is not None:
async_db_url_tail = url.partition("://")[2]
db_url_tail = conf.db_url.partition("://")[2]
if async_db_url_tail != db_url_tail:
console.warn(
f"async_db_url `{_safe_db_url_for_logging(url)}` "
"should reference the same database as "
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
)
if url is None:
raise ValueError("No async database url configured")
global _ASYNC_ENGINE
if url in _ASYNC_ENGINE:
return _ASYNC_ENGINE[url]
if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
url,
**get_engine_args(url),
)
return _ASYNC_ENGINE[url]
async def get_db_status() -> bool:
@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
return sqlmodel.Session(get_engine(url))
def asession(url: str | None = None) -> AsyncSession:
"""Get an async sqlmodel session to interact with the database.
async with rx.asession() as asession:
...
Most operations against the `asession` must be awaited.
Args:
url: The database url.
Returns:
An async database session.
"""
global _AsyncSessionLocal
if url not in _AsyncSessionLocal:
_AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
bind=get_async_engine(url),
class_=AsyncSession,
autocommit=False,
autoflush=False,
)
return _AsyncSessionLocal[url]()
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
"""Get a bare sqlalchemy session to interact with the database.