bare sqlalchemy session + tests (#3522)
* add bare sqlalchemy session, Closes #3512 * expose sqla_session at module level, add tests, improve typing * fix table name * add model_registry fixture, improve typing * did not meant to push this * add docstring to model_registry * do not expose sqla_session in reflex namespace
This commit is contained in:
parent
e4c17deafb
commit
9d71bcbbb5
1
.gitignore
vendored
1
.gitignore
vendored
@ -12,3 +12,4 @@ venv
|
||||
requirements.txt
|
||||
.pyi_generator_last_run
|
||||
.pyi_generator_diff
|
||||
reflex.db
|
||||
|
@ -24,7 +24,7 @@ from reflex.utils import console
|
||||
from reflex.utils.compat import sqlmodel
|
||||
|
||||
|
||||
def get_engine(url: str | None = None):
|
||||
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
||||
"""Get the database engine.
|
||||
|
||||
Args:
|
||||
@ -396,7 +396,7 @@ ModelRegistry.register(Model)
|
||||
|
||||
|
||||
def session(url: str | None = None) -> sqlmodel.Session:
|
||||
"""Get a session to interact with the database.
|
||||
"""Get a sqlmodel session to interact with the database.
|
||||
|
||||
Args:
|
||||
url: The database url.
|
||||
@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session:
|
||||
A database session.
|
||||
"""
|
||||
return sqlmodel.Session(get_engine(url))
|
||||
|
||||
|
||||
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
|
||||
"""Get a bare sqlalchemy session to interact with the database.
|
||||
|
||||
Args:
|
||||
url: The database url.
|
||||
|
||||
Returns:
|
||||
A database session.
|
||||
"""
|
||||
return sqlalchemy.orm.Session(get_engine(url))
|
||||
|
@ -5,13 +5,14 @@ import os
|
||||
import platform
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator
|
||||
from typing import Dict, Generator, Type
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from reflex.app import App
|
||||
from reflex.event import EventSpec
|
||||
from reflex.model import ModelRegistry
|
||||
from reflex.utils import prerequisites
|
||||
|
||||
from .states import (
|
||||
@ -247,3 +248,14 @@ def token() -> str:
|
||||
A fresh/unique token string.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_registry() -> Generator[Type[ModelRegistry], None, None]:
|
||||
"""Create a model registry.
|
||||
|
||||
Yields:
|
||||
A fresh model registry.
|
||||
"""
|
||||
yield ModelRegistry
|
||||
ModelRegistry._metadata = None
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -7,7 +8,7 @@ import sqlmodel
|
||||
|
||||
import reflex.constants
|
||||
import reflex.model
|
||||
from reflex.model import Model
|
||||
from reflex.model import Model, ModelRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -39,7 +40,7 @@ def model_custom_primary() -> Model:
|
||||
return ChildModel(name="name")
|
||||
|
||||
|
||||
def test_default_primary_key(model_default_primary):
|
||||
def test_default_primary_key(model_default_primary: Model):
|
||||
"""Test that if a primary key is not defined a default is added.
|
||||
|
||||
Args:
|
||||
@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
|
||||
assert "id" in model_default_primary.__class__.__fields__
|
||||
|
||||
|
||||
def test_custom_primary_key(model_custom_primary):
|
||||
def test_custom_primary_key(model_custom_primary: Model):
|
||||
"""Test that if a primary key is defined no default key is added.
|
||||
|
||||
Args:
|
||||
@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:This declarative base already contains a class with the same class name",
|
||||
)
|
||||
def test_automigration(tmp_working_dir, monkeypatch):
|
||||
def test_automigration(
|
||||
tmp_working_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
model_registry: Type[ModelRegistry],
|
||||
):
|
||||
"""Test alembic automigration with add and drop table and column.
|
||||
|
||||
Args:
|
||||
tmp_working_dir: directory where database and migrations are stored
|
||||
monkeypatch: pytest fixture to overwrite attributes
|
||||
model_registry: clean reflex ModelRegistry
|
||||
"""
|
||||
alembic_ini = tmp_working_dir / "alembic.ini"
|
||||
versions = tmp_working_dir / "alembic" / "versions"
|
||||
@ -84,8 +90,10 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
t1: str
|
||||
|
||||
with Model.get_db_engine().connect() as connection:
|
||||
Model.alembic_autogenerate(connection=connection, message="Initial Revision")
|
||||
Model.migrate()
|
||||
assert Model.alembic_autogenerate(
|
||||
connection=connection, message="Initial Revision"
|
||||
)
|
||||
assert Model.migrate()
|
||||
version_scripts = list(versions.glob("*.py"))
|
||||
assert len(version_scripts) == 1
|
||||
assert version_scripts[0].name.endswith("initial_revision.py")
|
||||
@ -94,14 +102,14 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
session.add(AlembicThing(id=None, t1="foo"))
|
||||
session.commit()
|
||||
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
# Create column t2, mark t1 as optional with default
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t1: Optional[str] = "default"
|
||||
t2: str = "bar"
|
||||
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 2
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -114,13 +122,13 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
assert result[1].t1 == "default"
|
||||
assert result[1].t2 == "baz"
|
||||
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
# Drop column t1
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t2: str = "bar"
|
||||
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 3
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -134,7 +142,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
a: int = 42
|
||||
b: float = 4.2
|
||||
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -146,16 +154,16 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
assert result[0].b == 4.2
|
||||
|
||||
# No-op
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
# drop table (AlembicSecond)
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t2: str = "bar"
|
||||
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -168,18 +176,18 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
assert result[0].t2 == "bar"
|
||||
assert result[1].t2 == "baz"
|
||||
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
# changing column type not supported by default
|
||||
t2: int = 42
|
||||
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
# clear all metadata to avoid influencing subsequent tests
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
# drop remaining tables
|
||||
Model.migrate(autogenerate=True)
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 6
|
||||
|
166
tests/test_sqlalchemy.py
Normal file
166
tests/test_sqlalchemy.py
Normal file
@ -0,0 +1,166 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
MappedAsDataclass,
|
||||
declared_attr,
|
||||
mapped_column,
|
||||
)
|
||||
|
||||
import reflex.constants
|
||||
import reflex.model
|
||||
from reflex.model import Model, ModelRegistry, sqla_session
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:This declarative base already contains a class with the same class name",
|
||||
)
|
||||
def test_automigration(
|
||||
tmp_working_dir: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
model_registry: Type[ModelRegistry],
|
||||
):
|
||||
"""Test alembic automigration with add and drop table and column.
|
||||
|
||||
Args:
|
||||
tmp_working_dir: directory where database and migrations are stored
|
||||
monkeypatch: pytest fixture to overwrite attributes
|
||||
model_registry: clean reflex ModelRegistry
|
||||
"""
|
||||
alembic_ini = tmp_working_dir / "alembic.ini"
|
||||
versions = tmp_working_dir / "alembic" / "versions"
|
||||
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
|
||||
|
||||
config_mock = mock.Mock()
|
||||
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
|
||||
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
|
||||
|
||||
assert alembic_ini.exists() is False
|
||||
assert versions.exists() is False
|
||||
Model.alembic_init()
|
||||
assert alembic_ini.exists()
|
||||
assert versions.exists()
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str:
|
||||
return cls.__name__.lower()
|
||||
|
||||
assert model_registry.register(Base)
|
||||
|
||||
class ModelBase(Base, MappedAsDataclass):
|
||||
__abstract__ = True
|
||||
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
|
||||
|
||||
# initial table
|
||||
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
|
||||
t1: Mapped[str] = mapped_column(default="")
|
||||
|
||||
with Model.get_db_engine().connect() as connection:
|
||||
assert Model.alembic_autogenerate(
|
||||
connection=connection, message="Initial Revision"
|
||||
)
|
||||
assert Model.migrate()
|
||||
version_scripts = list(versions.glob("*.py"))
|
||||
assert len(version_scripts) == 1
|
||||
assert version_scripts[0].name.endswith("initial_revision.py")
|
||||
|
||||
with sqla_session() as session:
|
||||
session.add(AlembicThing(t1="foo"))
|
||||
session.commit()
|
||||
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
# Create column t2, mark t1 as optional with default
|
||||
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
|
||||
t1: Mapped[Optional[str]] = mapped_column(default="default")
|
||||
t2: Mapped[str] = mapped_column(default="bar")
|
||||
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 2
|
||||
|
||||
with sqla_session() as session:
|
||||
session.add(AlembicThing(t2="baz"))
|
||||
session.commit()
|
||||
result = session.scalars(select(AlembicThing)).all()
|
||||
assert len(result) == 2
|
||||
assert result[0].t1 == "foo"
|
||||
assert result[0].t2 == "bar"
|
||||
assert result[1].t1 == "default"
|
||||
assert result[1].t2 == "baz"
|
||||
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
# Drop column t1
|
||||
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
|
||||
t2: Mapped[str] = mapped_column(default="bar")
|
||||
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 3
|
||||
|
||||
with sqla_session() as session:
|
||||
result = session.scalars(select(AlembicThing)).all()
|
||||
assert len(result) == 2
|
||||
assert result[0].t2 == "bar"
|
||||
assert result[1].t2 == "baz"
|
||||
|
||||
# Add table
|
||||
class AlembicSecond(ModelBase):
|
||||
a: Mapped[int] = mapped_column(default=42)
|
||||
b: Mapped[float] = mapped_column(default=4.2)
|
||||
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
with reflex.model.session() as session:
|
||||
session.add(AlembicSecond(id=None))
|
||||
session.commit()
|
||||
result = session.scalars(select(AlembicSecond)).all()
|
||||
assert len(result) == 1
|
||||
assert result[0].a == 42
|
||||
assert result[0].b == 4.2
|
||||
|
||||
# No-op
|
||||
# assert Model.migrate(autogenerate=True)
|
||||
# assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
# drop table (AlembicSecond)
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
|
||||
t2: Mapped[str] = mapped_column(default="bar")
|
||||
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
with reflex.model.session() as session:
|
||||
with pytest.raises(OperationalError) as errctx:
|
||||
_ = session.scalars(select(AlembicSecond)).all()
|
||||
assert errctx.match(r"no such table: alembicsecond")
|
||||
# first table should still exist
|
||||
result = session.scalars(select(AlembicThing)).all()
|
||||
assert len(result) == 2
|
||||
assert result[0].t2 == "bar"
|
||||
assert result[1].t2 == "baz"
|
||||
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
class AlembicThing(ModelBase):
|
||||
# changing column type not supported by default
|
||||
t2: Mapped[int] = mapped_column(default=42)
|
||||
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
# clear all metadata to avoid influencing subsequent tests
|
||||
model_registry.get_metadata().clear()
|
||||
|
||||
# drop remaining tables
|
||||
assert Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 6
|
Loading…
Reference in New Issue
Block a user