diff --git a/reflex/__init__.py b/reflex/__init__.py index 562524416..0be568b1a 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -331,7 +331,7 @@ _MAPPING: dict = { "SessionStorage", ], "middleware": ["middleware", "Middleware"], - "model": ["session", "Model"], + "model": ["asession", "session", "Model"], "state": [ "var", "ComponentState", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 6f61435e6..6bdfa11a0 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state from .middleware import Middleware as Middleware from .middleware import middleware as middleware from .model import Model as Model +from .model import asession as asession from .model import session as session from .page import page as page from .state import ComponentState as ComponentState diff --git a/reflex/config.py b/reflex/config.py index 1d952f4f0..a039d4c2f 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -621,6 +621,9 @@ class Config(Base): # The database url used by rx.Model. db_url: Optional[str] = "sqlite:///reflex.db" + # The async database url used by rx.Model. + async_db_url: Optional[str] = "sqlite+aiosqlite:///reflex.db" + # The redis url redis_url: Optional[str] = None diff --git a/reflex/model.py b/reflex/model.py index 19b1e71b7..53588f5c3 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -15,13 +15,16 @@ import alembic.util import sqlalchemy import sqlalchemy.exc import sqlalchemy.orm +from sqlmodel.ext.asyncio.session import AsyncSession from reflex.base import Base from reflex.config import environment, get_config from reflex.utils import console from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key -_ENGINE: sqlalchemy.engine.Engine | None = None +_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} +_ASYNC_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} +_AsyncSessionLocal: dict[str | None, sqlalchemy.orm.sessionmaker] = {} def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: @@ -37,8 +40,8 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: ValueError: If the database url is None. """ global _ENGINE - if _ENGINE is not None: - return _ENGINE + if url in _ENGINE: + return _ENGINE[url] conf = get_config() url = url or conf.db_url @@ -53,8 +56,44 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: # Needed for the admin dash on sqlite. connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {} - _ENGINE = sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args) - return _ENGINE + _ENGINE[url] = sqlmodel.create_engine( + url, echo=echo_db_query, connect_args=connect_args + ) + return _ENGINE[url] + + +def get_async_engine(url: str | None) -> sqlalchemy.engine.Engine: + """Get the async database engine. + + Args: + url: The database url. + + Returns: + The async database engine. + + Raises: + ValueError: If the async database url is None. + """ + global _ASYNC_ENGINE + if url in _ASYNC_ENGINE: + return _ASYNC_ENGINE[url] + + conf = get_config() + url = url or conf.async_db_url + if url is None: + raise ValueError("No async database url configured") + if not environment.ALEMBIC_CONFIG.get().exists(): + console.warn( + "Database is not initialized, run [bold]reflex db init[/bold] first." + ) + # Print the SQL queries if the log level is INFO or lower. + echo_db_query = environment.SQLALCHEMY_ECHO.get() + # Needed for the admin dash on sqlite. + connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {} + _ASYNC_ENGINE[url] = sqlmodel.create_engine( + url, echo=echo_db_query, connect_args=connect_args + ) + return _ASYNC_ENGINE[url] async def get_db_status() -> bool: @@ -433,6 +472,26 @@ def session(url: str | None = None) -> sqlmodel.Session: return sqlmodel.Session(get_engine(url)) +def asession(url: str | None = None) -> AsyncSession: + """Get an async sqlmodel session to interact with the database. + + Args: + url: The database url. + + Returns: + An async database session. + """ + global _AsyncSessionLocal + if url not in _AsyncSessionLocal: + _AsyncSessionLocal[url] = sqlalchemy.orm.sessionmaker( + autocommit=False, + autoflush=False, + bind=get_async_engine(url), + class_=AsyncSession, + ) + return _AsyncSessionLocal[url]() + + def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session: """Get a bare sqlalchemy session to interact with the database.