Implement rx.asession for async database support

Requires setting `async_db_url` to the same database as `db_url`, except using
an async driver; this will vary by database.
This commit is contained in:
Masen Furer 2024-12-05 22:07:56 -08:00
parent 264aade3bc
commit f3abe475e9
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
4 changed files with 69 additions and 6 deletions

View File

@ -331,7 +331,7 @@ _MAPPING: dict = {
"SessionStorage",
],
"middleware": ["middleware", "Middleware"],
"model": ["session", "Model"],
"model": ["asession", "session", "Model"],
"state": [
"var",
"ComponentState",

View File

@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
from .middleware import Middleware as Middleware
from .middleware import middleware as middleware
from .model import Model as Model
from .model import asession as asession
from .model import session as session
from .page import page as page
from .state import ComponentState as ComponentState

View File

@ -621,6 +621,9 @@ class Config(Base):
# The database url used by rx.Model.
db_url: Optional[str] = "sqlite:///reflex.db"
# The async database url used by rx.Model.
async_db_url: Optional[str] = "sqlite+aiosqlite:///reflex.db"
# The redis url
redis_url: Optional[str] = None

View File

@ -15,13 +15,16 @@ import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.orm
from sqlmodel.ext.asyncio.session import AsyncSession
from reflex.base import Base
from reflex.config import environment, get_config
from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
_ENGINE: sqlalchemy.engine.Engine | None = None
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {}
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
@ -37,8 +40,8 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
ValueError: If the database url is None.
"""
global _ENGINE
if _ENGINE is not None:
return _ENGINE
if url in _ENGINE:
return _ENGINE[url]
conf = get_config()
url = url or conf.db_url
@ -53,8 +56,44 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
# Needed for the admin dash on sqlite.
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
_ENGINE = sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
return _ENGINE
_ENGINE[url] = sqlmodel.create_engine(
url, echo=echo_db_query, connect_args=connect_args
)
return _ENGINE[url]
def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
"""Get the async database engine.
Args:
url: The database url.
Returns:
The async database engine.
Raises:
ValueError: If the async database url is None.
"""
global _ASYNC_ENGINE
if url in _ASYNC_ENGINE:
return _ASYNC_ENGINE[url]
conf = get_config()
url = url or conf.async_db_url
if url is None:
raise ValueError("No async database url configured")
if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
# Print the SQL queries if the log level is INFO or lower.
echo_db_query = environment.SQLALCHEMY_ECHO.get()
# Needed for the admin dash on sqlite.
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
_ASYNC_ENGINE[url] = sqlmodel.create_engine(
url, echo=echo_db_query, connect_args=connect_args
)
return _ASYNC_ENGINE[url]
async def get_db_status() -> bool:
@ -433,6 +472,26 @@ def session(url: str | None = None) -> sqlmodel.Session:
return sqlmodel.Session(get_engine(url))
def asession(url: str | None = None) -> AsyncSession:
"""Get an async sqlmodel session to interact with the database.
Args:
url: The database url.
Returns:
An async database session.
"""
global _AsyncSessionLocal
if url not in _AsyncSessionLocal:
_AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker(
autocommit=False,
autoflush=False,
bind=get_async_engine(url),
class_=AsyncSession,
)
return _AsyncSessionLocal[url]()
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
"""Get a bare sqlalchemy session to interact with the database.