From 58838180517a8730e7fc803050556f4a1e57a31b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 6 Dec 2024 12:32:00 -0800 Subject: [PATCH] create_async_engine: use sqlalchemy async API update types --- reflex/model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/reflex/model.py b/reflex/model.py index 334c15987..ea13a87ee 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -15,8 +15,8 @@ import alembic.script import alembic.util import sqlalchemy import sqlalchemy.exc +import sqlalchemy.ext.asyncio import sqlalchemy.orm -from sqlmodel.ext.asyncio.session import AsyncSession from reflex.base import Base from reflex.config import environment, get_config @@ -24,11 +24,11 @@ from reflex.utils import console from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key _ENGINE: dict[str, sqlalchemy.engine.Engine] = {} -_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} +_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {} _AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {} -# Import asyncio session _after_ reflex.utils.compat -import sqlmodel.ext.asyncio.session # noqa: E402 +# Import AsyncSession _after_ reflex.utils.compat +from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402 def _safe_db_url_for_logging(url: str) -> str: @@ -98,7 +98,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: return _ENGINE[url] -def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine: +def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine: """Get the async database engine. Args: @@ -133,7 +133,7 @@ def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine: console.warn( "Database is not initialized, run [bold]reflex db init[/bold] first." ) - _ASYNC_ENGINE[url] = sqlmodel.create_engine( + _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine( url, **get_engine_args(url), ) @@ -530,7 +530,7 @@ def asession(url: str | None = None) -> AsyncSession: _AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker( autocommit=False, autoflush=False, - bind=get_async_engine(url), + bind=get_async_engine(url), # type: ignore class_=AsyncSession, ) return _AsyncSessionLocal[url]()