bare sqlalchemy session + tests (#3522)

* add bare sqlalchemy session, Closes #3512

* expose sqla_session at module level, add tests, improve typing

* fix table name

* add model_registry fixture, improve typing

* did not meant to push this

* add docstring to model_registry

* do not expose sqla_session in reflex namespace
This commit is contained in:
benedikt-bartscher 2024-06-25 15:29:01 +02:00 committed by GitHub
parent e4c17deafb
commit 9d71bcbbb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 221 additions and 22 deletions

1
.gitignore vendored
View File

@ -12,3 +12,4 @@ venv
requirements.txt requirements.txt
.pyi_generator_last_run .pyi_generator_last_run
.pyi_generator_diff .pyi_generator_diff
reflex.db

View File

@ -24,7 +24,7 @@ from reflex.utils import console
from reflex.utils.compat import sqlmodel from reflex.utils.compat import sqlmodel
def get_engine(url: str | None = None): def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
"""Get the database engine. """Get the database engine.
Args: Args:
@ -396,7 +396,7 @@ ModelRegistry.register(Model)
def session(url: str | None = None) -> sqlmodel.Session: def session(url: str | None = None) -> sqlmodel.Session:
"""Get a session to interact with the database. """Get a sqlmodel session to interact with the database.
Args: Args:
url: The database url. url: The database url.
@ -405,3 +405,15 @@ def session(url: str | None = None) -> sqlmodel.Session:
A database session. A database session.
""" """
return sqlmodel.Session(get_engine(url)) return sqlmodel.Session(get_engine(url))
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
"""Get a bare sqlalchemy session to interact with the database.
Args:
url: The database url.
Returns:
A database session.
"""
return sqlalchemy.orm.Session(get_engine(url))

View File

@ -5,13 +5,14 @@ import os
import platform import platform
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Dict, Generator from typing import Dict, Generator, Type
from unittest import mock from unittest import mock
import pytest import pytest
from reflex.app import App from reflex.app import App
from reflex.event import EventSpec from reflex.event import EventSpec
from reflex.model import ModelRegistry
from reflex.utils import prerequisites from reflex.utils import prerequisites
from .states import ( from .states import (
@ -247,3 +248,14 @@ def token() -> str:
A fresh/unique token string. A fresh/unique token string.
""" """
return str(uuid.uuid4()) return str(uuid.uuid4())
@pytest.fixture
def model_registry() -> Generator[Type[ModelRegistry], None, None]:
"""Create a model registry.
Yields:
A fresh model registry.
"""
yield ModelRegistry
ModelRegistry._metadata = None

View File

@ -1,4 +1,5 @@
from typing import Optional from pathlib import Path
from typing import Optional, Type
from unittest import mock from unittest import mock
import pytest import pytest
@ -7,7 +8,7 @@ import sqlmodel
import reflex.constants import reflex.constants
import reflex.model import reflex.model
from reflex.model import Model from reflex.model import Model, ModelRegistry
@pytest.fixture @pytest.fixture
@ -39,7 +40,7 @@ def model_custom_primary() -> Model:
return ChildModel(name="name") return ChildModel(name="name")
def test_default_primary_key(model_default_primary): def test_default_primary_key(model_default_primary: Model):
"""Test that if a primary key is not defined a default is added. """Test that if a primary key is not defined a default is added.
Args: Args:
@ -48,7 +49,7 @@ def test_default_primary_key(model_default_primary):
assert "id" in model_default_primary.__class__.__fields__ assert "id" in model_default_primary.__class__.__fields__
def test_custom_primary_key(model_custom_primary): def test_custom_primary_key(model_custom_primary: Model):
"""Test that if a primary key is defined no default key is added. """Test that if a primary key is defined no default key is added.
Args: Args:
@ -60,12 +61,17 @@ def test_custom_primary_key(model_custom_primary):
@pytest.mark.filterwarnings( @pytest.mark.filterwarnings(
"ignore:This declarative base already contains a class with the same class name", "ignore:This declarative base already contains a class with the same class name",
) )
def test_automigration(tmp_working_dir, monkeypatch): def test_automigration(
tmp_working_dir: Path,
monkeypatch: pytest.MonkeyPatch,
model_registry: Type[ModelRegistry],
):
"""Test alembic automigration with add and drop table and column. """Test alembic automigration with add and drop table and column.
Args: Args:
tmp_working_dir: directory where database and migrations are stored tmp_working_dir: directory where database and migrations are stored
monkeypatch: pytest fixture to overwrite attributes monkeypatch: pytest fixture to overwrite attributes
model_registry: clean reflex ModelRegistry
""" """
alembic_ini = tmp_working_dir / "alembic.ini" alembic_ini = tmp_working_dir / "alembic.ini"
versions = tmp_working_dir / "alembic" / "versions" versions = tmp_working_dir / "alembic" / "versions"
@ -84,8 +90,10 @@ def test_automigration(tmp_working_dir, monkeypatch):
t1: str t1: str
with Model.get_db_engine().connect() as connection: with Model.get_db_engine().connect() as connection:
Model.alembic_autogenerate(connection=connection, message="Initial Revision") assert Model.alembic_autogenerate(
Model.migrate() connection=connection, message="Initial Revision"
)
assert Model.migrate()
version_scripts = list(versions.glob("*.py")) version_scripts = list(versions.glob("*.py"))
assert len(version_scripts) == 1 assert len(version_scripts) == 1
assert version_scripts[0].name.endswith("initial_revision.py") assert version_scripts[0].name.endswith("initial_revision.py")
@ -94,14 +102,14 @@ def test_automigration(tmp_working_dir, monkeypatch):
session.add(AlembicThing(id=None, t1="foo")) session.add(AlembicThing(id=None, t1="foo"))
session.commit() session.commit()
sqlmodel.SQLModel.metadata.clear() model_registry.get_metadata().clear()
# Create column t2, mark t1 as optional with default # Create column t2, mark t1 as optional with default
class AlembicThing(Model, table=True): # type: ignore class AlembicThing(Model, table=True): # type: ignore
t1: Optional[str] = "default" t1: Optional[str] = "default"
t2: str = "bar" t2: str = "bar"
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 2 assert len(list(versions.glob("*.py"))) == 2
with reflex.model.session() as session: with reflex.model.session() as session:
@ -114,13 +122,13 @@ def test_automigration(tmp_working_dir, monkeypatch):
assert result[1].t1 == "default" assert result[1].t1 == "default"
assert result[1].t2 == "baz" assert result[1].t2 == "baz"
sqlmodel.SQLModel.metadata.clear() model_registry.get_metadata().clear()
# Drop column t1 # Drop column t1
class AlembicThing(Model, table=True): # type: ignore class AlembicThing(Model, table=True): # type: ignore
t2: str = "bar" t2: str = "bar"
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 3 assert len(list(versions.glob("*.py"))) == 3
with reflex.model.session() as session: with reflex.model.session() as session:
@ -134,7 +142,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
a: int = 42 a: int = 42
b: float = 4.2 b: float = 4.2
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 4 assert len(list(versions.glob("*.py"))) == 4
with reflex.model.session() as session: with reflex.model.session() as session:
@ -146,16 +154,16 @@ def test_automigration(tmp_working_dir, monkeypatch):
assert result[0].b == 4.2 assert result[0].b == 4.2
# No-op # No-op
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 4 assert len(list(versions.glob("*.py"))) == 4
# drop table (AlembicSecond) # drop table (AlembicSecond)
sqlmodel.SQLModel.metadata.clear() model_registry.get_metadata().clear()
class AlembicThing(Model, table=True): # type: ignore class AlembicThing(Model, table=True): # type: ignore
t2: str = "bar" t2: str = "bar"
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5 assert len(list(versions.glob("*.py"))) == 5
with reflex.model.session() as session: with reflex.model.session() as session:
@ -168,18 +176,18 @@ def test_automigration(tmp_working_dir, monkeypatch):
assert result[0].t2 == "bar" assert result[0].t2 == "bar"
assert result[1].t2 == "baz" assert result[1].t2 == "baz"
sqlmodel.SQLModel.metadata.clear() model_registry.get_metadata().clear()
class AlembicThing(Model, table=True): # type: ignore class AlembicThing(Model, table=True): # type: ignore
# changing column type not supported by default # changing column type not supported by default
t2: int = 42 t2: int = 42
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5 assert len(list(versions.glob("*.py"))) == 5
# clear all metadata to avoid influencing subsequent tests # clear all metadata to avoid influencing subsequent tests
sqlmodel.SQLModel.metadata.clear() model_registry.get_metadata().clear()
# drop remaining tables # drop remaining tables
Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 6 assert len(list(versions.glob("*.py"))) == 6

166
tests/test_sqlalchemy.py Normal file
View File

@ -0,0 +1,166 @@
from pathlib import Path
from typing import Optional, Type
from unittest import mock
import pytest
from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
declared_attr,
mapped_column,
)
import reflex.constants
import reflex.model
from reflex.model import Model, ModelRegistry, sqla_session
@pytest.mark.filterwarnings(
"ignore:This declarative base already contains a class with the same class name",
)
def test_automigration(
tmp_working_dir: Path,
monkeypatch: pytest.MonkeyPatch,
model_registry: Type[ModelRegistry],
):
"""Test alembic automigration with add and drop table and column.
Args:
tmp_working_dir: directory where database and migrations are stored
monkeypatch: pytest fixture to overwrite attributes
model_registry: clean reflex ModelRegistry
"""
alembic_ini = tmp_working_dir / "alembic.ini"
versions = tmp_working_dir / "alembic" / "versions"
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
config_mock = mock.Mock()
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
assert alembic_ini.exists() is False
assert versions.exists() is False
Model.alembic_init()
assert alembic_ini.exists()
assert versions.exists()
class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
return cls.__name__.lower()
assert model_registry.register(Base)
class ModelBase(Base, MappedAsDataclass):
__abstract__ = True
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
# initial table
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t1: Mapped[str] = mapped_column(default="")
with Model.get_db_engine().connect() as connection:
assert Model.alembic_autogenerate(
connection=connection, message="Initial Revision"
)
assert Model.migrate()
version_scripts = list(versions.glob("*.py"))
assert len(version_scripts) == 1
assert version_scripts[0].name.endswith("initial_revision.py")
with sqla_session() as session:
session.add(AlembicThing(t1="foo"))
session.commit()
model_registry.get_metadata().clear()
# Create column t2, mark t1 as optional with default
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t1: Mapped[Optional[str]] = mapped_column(default="default")
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 2
with sqla_session() as session:
session.add(AlembicThing(t2="baz"))
session.commit()
result = session.scalars(select(AlembicThing)).all()
assert len(result) == 2
assert result[0].t1 == "foo"
assert result[0].t2 == "bar"
assert result[1].t1 == "default"
assert result[1].t2 == "baz"
model_registry.get_metadata().clear()
# Drop column t1
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 3
with sqla_session() as session:
result = session.scalars(select(AlembicThing)).all()
assert len(result) == 2
assert result[0].t2 == "bar"
assert result[1].t2 == "baz"
# Add table
class AlembicSecond(ModelBase):
a: Mapped[int] = mapped_column(default=42)
b: Mapped[float] = mapped_column(default=4.2)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 4
with reflex.model.session() as session:
session.add(AlembicSecond(id=None))
session.commit()
result = session.scalars(select(AlembicSecond)).all()
assert len(result) == 1
assert result[0].a == 42
assert result[0].b == 4.2
# No-op
# assert Model.migrate(autogenerate=True)
# assert len(list(versions.glob("*.py"))) == 4
# drop table (AlembicSecond)
model_registry.get_metadata().clear()
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5
with reflex.model.session() as session:
with pytest.raises(OperationalError) as errctx:
_ = session.scalars(select(AlembicSecond)).all()
assert errctx.match(r"no such table: alembicsecond")
# first table should still exist
result = session.scalars(select(AlembicThing)).all()
assert len(result) == 2
assert result[0].t2 == "bar"
assert result[1].t2 == "baz"
model_registry.get_metadata().clear()
class AlembicThing(ModelBase):
# changing column type not supported by default
t2: Mapped[int] = mapped_column(default=42)
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5
# clear all metadata to avoid influencing subsequent tests
model_registry.get_metadata().clear()
# drop remaining tables
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 6