create_async_engine: use sqlalchemy async API

update types
This commit is contained in:
Masen Furer 2024-12-06 12:32:00 -08:00
parent 764d5768fa
commit 5883818051
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -15,8 +15,8 @@ import alembic.script
import alembic.util import alembic.util
import sqlalchemy import sqlalchemy
import sqlalchemy.exc import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm import sqlalchemy.orm
from sqlmodel.ext.asyncio.session import AsyncSession
from reflex.base import Base from reflex.base import Base
from reflex.config import environment, get_config from reflex.config import environment, get_config
@ -24,11 +24,11 @@ from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} _ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} _ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {} _AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {}
# Import asyncio session _after_ reflex.utils.compat # Import AsyncSession _after_ reflex.utils.compat
import sqlmodel.ext.asyncio.session # noqa: E402 from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
def _safe_db_url_for_logging(url: str) -> str: def _safe_db_url_for_logging(url: str) -> str:
@ -98,7 +98,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
return _ENGINE[url] return _ENGINE[url]
def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine: def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
"""Get the async database engine. """Get the async database engine.
Args: Args:
@ -133,7 +133,7 @@ def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
console.warn( console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first." "Database is not initialized, run [bold]reflex db init[/bold] first."
) )
_ASYNC_ENGINE[url] = sqlmodel.create_engine( _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
url, url,
**get_engine_args(url), **get_engine_args(url),
) )
@ -530,7 +530,7 @@ def asession(url: str | None = None) -> AsyncSession:
_AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker( _AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker(
autocommit=False, autocommit=False,
autoflush=False, autoflush=False,
bind=get_async_engine(url), bind=get_async_engine(url), # type: ignore
class_=AsyncSession, class_=AsyncSession,
) )
return _AsyncSessionLocal[url]() return _AsyncSessionLocal[url]()