Reuse the sqlalchemy engine once it's created (#4493)
* Reuse the sqlalchemy engine once it's created * Implement `rx.asession` for async database support Requires setting `async_db_url` to the same database as `db_url`, except using an async driver; this will vary by database. * resolve the url first, so the key into _ENGINE is correct * Ping db connections before returning them from pool Move connect engine kwargs to a separate function * the param is `echo` * sanity check that config db_url and async_db_url are the same throw a warning if the part following the `://` differs between these two * create_async_engine: use sqlalchemy async API update types * redact ASYNC_DB_URL similarly to DB_URL when overridden in config * update rx.asession docstring * use async_sessionmaker * Redact sensitive env vars instead of hiding them
This commit is contained in:
parent
3ef7106e0e
commit
4922f7ba05
@ -331,7 +331,7 @@ _MAPPING: dict = {
|
||||
"SessionStorage",
|
||||
],
|
||||
"middleware": ["middleware", "Middleware"],
|
||||
"model": ["session", "Model"],
|
||||
"model": ["asession", "session", "Model"],
|
||||
"state": [
|
||||
"var",
|
||||
"ComponentState",
|
||||
|
@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
|
||||
from .middleware import Middleware as Middleware
|
||||
from .middleware import middleware as middleware
|
||||
from .model import Model as Model
|
||||
from .model import asession as asession
|
||||
from .model import session as session
|
||||
from .page import page as page
|
||||
from .state import ComponentState as ComponentState
|
||||
|
@ -512,6 +512,9 @@ class EnvironmentVariables:
|
||||
# Whether to print the SQL queries if the log level is INFO or lower.
|
||||
SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
|
||||
|
||||
# Whether to check db connections before using them.
|
||||
SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)
|
||||
|
||||
# Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
|
||||
REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
|
||||
|
||||
@ -568,6 +571,10 @@ class EnvironmentVariables:
|
||||
environment = EnvironmentVariables()
|
||||
|
||||
|
||||
# These vars are not logged because they may contain sensitive information.
|
||||
_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}
|
||||
|
||||
|
||||
class Config(Base):
|
||||
"""The config defines runtime settings for the app.
|
||||
|
||||
@ -621,6 +628,9 @@ class Config(Base):
|
||||
# The database url used by rx.Model.
|
||||
db_url: Optional[str] = "sqlite:///reflex.db"
|
||||
|
||||
# The async database url used by rx.Model.
|
||||
async_db_url: Optional[str] = None
|
||||
|
||||
# The redis url
|
||||
redis_url: Optional[str] = None
|
||||
|
||||
@ -748,18 +758,20 @@ class Config(Base):
|
||||
|
||||
# If the env var is set, override the config value.
|
||||
if env_var is not None:
|
||||
if key.upper() != "DB_URL":
|
||||
console.info(
|
||||
f"Overriding config value {key} with env var {key.upper()}={env_var}",
|
||||
dedupe=True,
|
||||
)
|
||||
|
||||
# Interpret the value.
|
||||
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
|
||||
|
||||
# Set the value.
|
||||
updated_values[key] = value
|
||||
|
||||
if key.upper() in _sensitive_env_vars:
|
||||
env_var = "***"
|
||||
|
||||
console.info(
|
||||
f"Overriding config value {key} with env var {key.upper()}={env_var}",
|
||||
dedupe=True,
|
||||
)
|
||||
|
||||
return updated_values
|
||||
|
||||
def get_event_namespace(self) -> str:
|
||||
|
126
reflex/model.py
126
reflex/model.py
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, ClassVar, Optional, Type, Union
|
||||
|
||||
@ -14,6 +15,7 @@ import alembic.script
|
||||
import alembic.util
|
||||
import sqlalchemy
|
||||
import sqlalchemy.exc
|
||||
import sqlalchemy.ext.asyncio
|
||||
import sqlalchemy.orm
|
||||
|
||||
from reflex.base import Base
|
||||
@ -21,6 +23,48 @@ from reflex.config import environment, get_config
|
||||
from reflex.utils import console
|
||||
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
|
||||
|
||||
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
||||
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
|
||||
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
|
||||
|
||||
# Import AsyncSession _after_ reflex.utils.compat
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
|
||||
|
||||
|
||||
def _safe_db_url_for_logging(url: str) -> str:
|
||||
"""Remove username and password from the database URL for logging.
|
||||
|
||||
Args:
|
||||
url: The database URL.
|
||||
|
||||
Returns:
|
||||
The database URL with the username and password removed.
|
||||
"""
|
||||
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
|
||||
|
||||
|
||||
def get_engine_args(url: str | None = None) -> dict[str, Any]:
|
||||
"""Get the database engine arguments.
|
||||
|
||||
Args:
|
||||
url: The database url.
|
||||
|
||||
Returns:
|
||||
The database engine arguments as a dict.
|
||||
"""
|
||||
kwargs: dict[str, Any] = dict(
|
||||
# Print the SQL queries if the log level is INFO or lower.
|
||||
echo=environment.SQLALCHEMY_ECHO.get(),
|
||||
# Check connections before returning them.
|
||||
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
|
||||
)
|
||||
conf = get_config()
|
||||
url = url or conf.db_url
|
||||
if url is not None and url.startswith("sqlite"):
|
||||
# Needed for the admin dash on sqlite.
|
||||
kwargs["connect_args"] = {"check_same_thread": False}
|
||||
return kwargs
|
||||
|
||||
|
||||
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
||||
"""Get the database engine.
|
||||
@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
||||
url = url or conf.db_url
|
||||
if url is None:
|
||||
raise ValueError("No database url configured")
|
||||
|
||||
global _ENGINE
|
||||
if url in _ENGINE:
|
||||
return _ENGINE[url]
|
||||
|
||||
if not environment.ALEMBIC_CONFIG.get().exists():
|
||||
console.warn(
|
||||
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||
)
|
||||
# Print the SQL queries if the log level is INFO or lower.
|
||||
echo_db_query = environment.SQLALCHEMY_ECHO.get()
|
||||
# Needed for the admin dash on sqlite.
|
||||
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
|
||||
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
||||
_ENGINE[url] = sqlmodel.create_engine(
|
||||
url,
|
||||
**get_engine_args(url),
|
||||
)
|
||||
return _ENGINE[url]
|
||||
|
||||
|
||||
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
|
||||
"""Get the async database engine.
|
||||
|
||||
Args:
|
||||
url: The database url.
|
||||
|
||||
Returns:
|
||||
The async database engine.
|
||||
|
||||
Raises:
|
||||
ValueError: If the async database url is None.
|
||||
"""
|
||||
if url is None:
|
||||
conf = get_config()
|
||||
url = conf.async_db_url
|
||||
if url is not None and conf.db_url is not None:
|
||||
async_db_url_tail = url.partition("://")[2]
|
||||
db_url_tail = conf.db_url.partition("://")[2]
|
||||
if async_db_url_tail != db_url_tail:
|
||||
console.warn(
|
||||
f"async_db_url `{_safe_db_url_for_logging(url)}` "
|
||||
"should reference the same database as "
|
||||
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
|
||||
)
|
||||
if url is None:
|
||||
raise ValueError("No async database url configured")
|
||||
|
||||
global _ASYNC_ENGINE
|
||||
if url in _ASYNC_ENGINE:
|
||||
return _ASYNC_ENGINE[url]
|
||||
|
||||
if not environment.ALEMBIC_CONFIG.get().exists():
|
||||
console.warn(
|
||||
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||
)
|
||||
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
|
||||
url,
|
||||
**get_engine_args(url),
|
||||
)
|
||||
return _ASYNC_ENGINE[url]
|
||||
|
||||
|
||||
async def get_db_status() -> bool:
|
||||
@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
|
||||
return sqlmodel.Session(get_engine(url))
|
||||
|
||||
|
||||
def asession(url: str | None = None) -> AsyncSession:
|
||||
"""Get an async sqlmodel session to interact with the database.
|
||||
|
||||
async with rx.asession() as asession:
|
||||
...
|
||||
|
||||
Most operations against the `asession` must be awaited.
|
||||
|
||||
Args:
|
||||
url: The database url.
|
||||
|
||||
Returns:
|
||||
An async database session.
|
||||
"""
|
||||
global _AsyncSessionLocal
|
||||
if url not in _AsyncSessionLocal:
|
||||
_AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
|
||||
bind=get_async_engine(url),
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
return _AsyncSessionLocal[url]()
|
||||
|
||||
|
||||
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
|
||||
"""Get a bare sqlalchemy session to interact with the database.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user