create_async_engine: use sqlalchemy async API
update types
This commit is contained in:
parent
764d5768fa
commit
5883818051
@ -15,8 +15,8 @@ import alembic.script
|
|||||||
import alembic.util
|
import alembic.util
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import sqlalchemy.exc
|
import sqlalchemy.exc
|
||||||
|
import sqlalchemy.ext.asyncio
|
||||||
import sqlalchemy.orm
|
import sqlalchemy.orm
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.config import environment, get_config
|
from reflex.config import environment, get_config
|
||||||
@ -24,11 +24,11 @@ from reflex.utils import console
|
|||||||
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
|
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
|
||||||
|
|
||||||
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
||||||
_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
|
||||||
_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {}
|
_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {}
|
||||||
|
|
||||||
# Import asyncio session _after_ reflex.utils.compat
|
# Import AsyncSession _after_ reflex.utils.compat
|
||||||
import sqlmodel.ext.asyncio.session # noqa: E402
|
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def _safe_db_url_for_logging(url: str) -> str:
|
def _safe_db_url_for_logging(url: str) -> str:
|
||||||
@ -98,7 +98,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
|||||||
return _ENGINE[url]
|
return _ENGINE[url]
|
||||||
|
|
||||||
|
|
||||||
def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
|
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
|
||||||
"""Get the async database engine.
|
"""Get the async database engine.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -133,7 +133,7 @@ def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine:
|
|||||||
console.warn(
|
console.warn(
|
||||||
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||||
)
|
)
|
||||||
_ASYNC_ENGINE[url] = sqlmodel.create_engine(
|
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
|
||||||
url,
|
url,
|
||||||
**get_engine_args(url),
|
**get_engine_args(url),
|
||||||
)
|
)
|
||||||
@ -530,7 +530,7 @@ def asession(url: str | None = None) -> AsyncSession:
|
|||||||
_AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker(
|
_AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker(
|
||||||
autocommit=False,
|
autocommit=False,
|
||||||
autoflush=False,
|
autoflush=False,
|
||||||
bind=get_async_engine(url),
|
bind=get_async_engine(url), # type: ignore
|
||||||
class_=AsyncSession,
|
class_=AsyncSession,
|
||||||
)
|
)
|
||||||
return _AsyncSessionLocal[url]()
|
return _AsyncSessionLocal[url]()
|
||||||
|
Loading…
Reference in New Issue
Block a user