create_async_engine: use sqlalchemy async API

update types
This commit is contained in:
Masen Furer 2024-12-06 12:32:00 -08:00
parent 764d5768fa
commit 5883818051
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -15,8 +15,8 @@ import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm
from sqlmodel.ext.asyncio.session import AsyncSession
from reflex.base import Base
from reflex.config import environment, get_config
@ -24,11 +24,11 @@ from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {}
# Import asyncio session _after_ reflex.utils.compat
import sqlmodel.ext.asyncio.session # noqa: E402
# Import AsyncSession _after_ reflex.utils.compat
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
def _safe_db_url_for_logging(url: str) -> str:
@ -98,7 +98,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
return _ENGINE[url]
def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
"""Get the async database engine.
Args:
@ -133,7 +133,7 @@ def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
_ASYNC_ENGINE[url] = sqlmodel.create_engine(
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
url,
**get_engine_args(url),
)
@ -530,7 +530,7 @@ def asession(url: str | None = None) -> AsyncSession:
_AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker(
autocommit=False,
autoflush=False,
bind=get_async_engine(url),
bind=get_async_engine(url), # type: ignore
class_=AsyncSession,
)
return _AsyncSessionLocal[url]()