rx.Model: automigrate using alembic (#1321)

This commit is contained in:
Masen Furer 2023-07-12 15:47:19 -07:00 committed by GitHub
parent fc87589434
commit 5505d10989
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 406 additions and 3 deletions

61
poetry.lock generated
View File

@ -1,5 +1,26 @@
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "alembic"
version = "1.11.1"
description = "A database migration tool for SQLAlchemy."
optional = false
python-versions = ">=3.7"
files = [
{file = "alembic-1.11.1-py3-none-any.whl", hash = "sha256:dc871798a601fab38332e38d6ddb38d5e734f60034baeb8e2db5b642fccd8ab8"},
{file = "alembic-1.11.1.tar.gz", hash = "sha256:6a810a6b012c88b33458fceb869aef09ac75d6ace5291915ba7fae44de372c01"},
]
[package.dependencies]
importlib-metadata = {version = "*", markers = "python_version < \"3.9\""}
importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
Mako = "*"
SQLAlchemy = ">=1.3.0"
typing-extensions = ">=4"
[package.extras]
tz = ["python-dateutil"]
[[package]]
name = "anyio"
version = "3.7.1"
@ -501,6 +522,24 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
perf = ["ipython"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
[[package]]
name = "importlib-resources"
version = "5.12.0"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.7"
files = [
{file = "importlib_resources-5.12.0-py3-none-any.whl", hash = "sha256:7b1deeebbf351c7578e09bf2f63fa2ce8b5ffec296e0d349139d43cca061a81a"},
{file = "importlib_resources-5.12.0.tar.gz", hash = "sha256:4be82589bf5c1d7999aedf2a45159d10cb3ca4f19b2271f8792bc8e6da7b22f6"},
]
[package.dependencies]
zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
[package.extras]
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
@ -529,6 +568,26 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
[[package]]
name = "mako"
version = "1.2.4"
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
optional = false
python-versions = ">=3.7"
files = [
{file = "Mako-1.2.4-py3-none-any.whl", hash = "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818"},
{file = "Mako-1.2.4.tar.gz", hash = "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"},
]
[package.dependencies]
importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
MarkupSafe = ">=0.9.2"
[package.extras]
babel = ["Babel"]
lingua = ["lingua"]
testing = ["pytest"]
[[package]]
name = "markdown-it-py"
version = "2.2.0"
@ -1854,4 +1913,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = "^3.7"
content-hash = "b07623f193778651dd3ece81370d2d9a2ba3b53855a57baf0147e67bf07aaed8"
content-hash = "cc659e46041316bc81ce1758334c6fa9ccf9812612ef67e170b173d6d1caa2b2"

View File

@ -62,6 +62,7 @@ pandas = [
]
asynctest = "^0.13.0"
pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
alembic = "^1.11.1"
[tool.poetry.scripts]
reflex = "reflex.reflex:main"

View File

@ -444,7 +444,7 @@ class App(Base):
config = get_config()
# Update models during hot reload.
if config.db_url is not None:
if config.db_url is not None and not Model.automigrate():
Model.create_all()
# Empty the .web pages directory

View File

@ -353,3 +353,6 @@ TOGGLE_COLOR_MODE = "toggleColorMode"
# Server socket configuration variables
CORS_ALLOWED_ORIGINS = get_value("CORS_ALLOWED_ORIGINS", ["*"], list)
POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000
# Alembic migrations
ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")

View File

@ -1,12 +1,30 @@
"""Database built into Reflex."""
from typing import Optional
from collections import defaultdict
from pathlib import Path
from typing import Any, Optional
import sqlalchemy
import sqlmodel
from reflex.base import Base
from reflex.config import get_config
from . import constants
try:
import alembic.autogenerate # pyright: ignore [reportMissingImports]
import alembic.command # pyright: ignore [reportMissingImports]
import alembic.operations.ops # pyright: ignore [reportMissingImports]
import alembic.runtime.environment # pyright: ignore [reportMissingImports]
import alembic.script # pyright: ignore [reportMissingImports]
import alembic.util # pyright: ignore [reportMissingImports]
from alembic.config import Config # pyright: ignore [reportMissingImports]
has_alembic = True
except ImportError:
has_alembic = False
def get_engine(url: Optional[str] = None):
"""Get the database engine.
@ -75,6 +93,154 @@ class Model(Base, sqlmodel.SQLModel):
"""
return get_engine()
@staticmethod
def _alembic_config():
"""Get the alembic configuration and script_directory.
Returns:
tuple of (config, script_directory)
"""
config = Config(constants.ALEMBIC_CONFIG)
return config, alembic.script.ScriptDirectory(
config.get_main_option("script_location", default="version"),
)
@staticmethod
def _alembic_render_item(
type_: str,
obj: Any,
autogen_context: "alembic.autogenerate.api.AutogenContext",
):
"""Alembic render_item hook call.
This method is called to provide python code for the given obj,
but currently it is only used to add `sqlmodel` to the import list
when generating migration scripts.
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
Args:
type_: one of "schema", "table", "column", "index",
"unique_constraint", or "foreign_key_constraint"
obj: the object being rendered
autogen_context: shared AutogenContext passed to each render_item call
Returns:
False - indicating that the default rendering should be used.
"""
autogen_context.imports.add("import sqlmodel")
return False
@classmethod
def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
"""Generate migration scripts for alembic-detectable changes.
Args:
connection: sqlalchemy connection to use when detecting changes
Returns:
True when changes have been detected.
"""
config, script_directory = cls._alembic_config()
revision_context = alembic.autogenerate.api.RevisionContext(
config=config,
script_directory=script_directory,
command_args=defaultdict(
lambda: None,
autogenerate=True,
head="head",
),
)
writer = alembic.autogenerate.rewriter.Rewriter()
@writer.rewrites(alembic.operations.ops.AddColumnOp)
def render_add_column_with_server_default(context, revision, op):
# Carry the sqlmodel default as server_default so that newly added
# columns get the desired default value in existing rows
if op.column.default is not None and op.column.server_default is None:
op.column.server_default = sqlalchemy.DefaultClause(
sqlalchemy.sql.expression.literal(op.column.default.arg),
)
return op
def run_autogenerate(rev, context):
revision_context.run_autogenerate(rev, context)
return []
with alembic.runtime.environment.EnvironmentContext(
config=config,
script=script_directory,
fn=run_autogenerate,
) as env:
env.configure(
connection=connection,
target_metadata=sqlmodel.SQLModel.metadata,
render_item=cls._alembic_render_item,
process_revision_directives=writer, # type: ignore
)
env.run_migrations()
changes_detected = False
if revision_context.generated_revisions:
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
if upgrade_ops is not None:
changes_detected = bool(upgrade_ops.ops)
if changes_detected:
for _script in revision_context.generate_scripts():
pass # must iterate to actually generate the scripts
return changes_detected
@classmethod
def _alembic_upgrade(
cls,
connection: sqlalchemy.engine.Connection,
to_rev: str = "head",
) -> None:
"""Apply alembic migrations up to the given revision.
Args:
connection: sqlalchemy connection to use when performing upgrade
to_rev: revision to migrate towards
"""
config, script_directory = cls._alembic_config()
def run_upgrade(rev, context):
return script_directory._upgrade_revs(to_rev, rev)
# apply updates to database
with alembic.runtime.environment.EnvironmentContext(
config=config,
script=script_directory,
fn=run_upgrade,
) as env:
env.configure(connection=connection)
env.run_migrations()
@classmethod
def automigrate(cls) -> Optional[bool]:
"""Generate and execute migrations for all sqlmodel Model classes.
If alembic is not installed or has not been initialized for the project,
then no action is performed.
If models in the app have changed in incompatible ways that alembic
cannot automatically generate revisions for, the app may not be able to
start up until migration scripts have been corrected by hand.
Returns:
True - indicating the process was successful
None - indicating the process was skipped
"""
if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
return
with cls.get_db_engine().connect() as connection:
cls._alembic_upgrade(connection=connection)
changes_detected = cls._alembic_autogenerate(connection=connection)
if changes_detected:
cls._alembic_upgrade(connection=connection)
connection.commit()
return True
@classmethod
@property
def select(cls):

View File

@ -1,5 +1,8 @@
"""Test fixtures."""
import contextlib
import os
import platform
from pathlib import Path
from typing import Dict, Generator, List
import pytest
@ -479,3 +482,48 @@ def router_data(router_data_headers) -> Dict[str, str]:
"headers": router_data_headers,
"ip": "127.0.0.1",
}
# borrowed from py3.11
class chdir(contextlib.AbstractContextManager):
"""Non thread-safe context manager to change the current working directory."""
def __init__(self, path):
"""Prepare contextmanager.
Args:
path: the path to change to
"""
self.path = path
self._old_cwd = []
def __enter__(self):
"""Save current directory and perform chdir."""
self._old_cwd.append(Path(".").resolve())
os.chdir(self.path)
def __exit__(self, *excinfo):
"""Change back to previous directory on stack.
Args:
excinfo: sys.exc_info captured in the context block
"""
os.chdir(self._old_cwd.pop())
@pytest.fixture
def tmp_working_dir(tmp_path):
"""Create a temporary directory and chdir to it.
After the test executes, chdir back to the original working directory.
Args:
tmp_path: pytest tmp_path fixture creates per-test temp dir
Yields:
subdirectory of tmp_path which is now the current working directory.
"""
working_dir = tmp_path / "working_dir"
working_dir.mkdir()
with chdir(working_dir):
yield working_dir

View File

@ -1,6 +1,13 @@
import subprocess
import sys
from unittest import mock
import pytest
import sqlalchemy
import sqlmodel
import reflex.constants
import reflex.model
from reflex.model import Model
@ -49,3 +56,122 @@ def test_custom_primary_key(model_custom_primary):
model_custom_primary: Fixture.
"""
assert "id" not in model_custom_primary.__class__.__fields__
@pytest.mark.filterwarnings(
"ignore:This declarative base already contains a class with the same class name",
)
def test_automigration(tmp_working_dir, monkeypatch):
"""Test alembic automigration with add and drop table and column.
Args:
tmp_working_dir: directory where database and migrations are stored
monkeypatch: pytest fixture to overwrite attributes
"""
subprocess.run(
[sys.executable, "-m", "alembic", "init", "alembic"],
cwd=tmp_working_dir,
)
alembic_ini = tmp_working_dir / "alembic.ini"
versions = tmp_working_dir / "alembic" / "versions"
assert alembic_ini.exists()
assert versions.exists()
config_mock = mock.Mock()
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
# initial table
class AlembicThing(Model, table=True): # type: ignore
t1: str
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 1
with reflex.model.session() as session:
session.add(AlembicThing(id=None, t1="foo"))
session.commit()
sqlmodel.SQLModel.metadata.clear()
# Create column t2
class AlembicThing(Model, table=True): # type: ignore
t1: str
t2: str = "bar"
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 2
with reflex.model.session() as session:
result = session.exec(sqlmodel.select(AlembicThing)).all()
assert len(result) == 1
assert result[0].t2 == "bar"
sqlmodel.SQLModel.metadata.clear()
# Drop column t1
class AlembicThing(Model, table=True): # type: ignore
t2: str = "bar"
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 3
with reflex.model.session() as session:
result = session.exec(sqlmodel.select(AlembicThing)).all()
assert len(result) == 1
assert result[0].t2 == "bar"
# Add table
class AlembicSecond(Model, table=True): # type: ignore
a: int = 42
b: float = 4.2
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 4
with reflex.model.session() as session:
session.add(AlembicSecond(id=None))
session.commit()
result = session.exec(sqlmodel.select(AlembicSecond)).all()
assert len(result) == 1
assert result[0].a == 42
assert result[0].b == 4.2
# No-op
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 4
# drop table (AlembicSecond)
sqlmodel.SQLModel.metadata.clear()
class AlembicThing(Model, table=True): # type: ignore
t2: str = "bar"
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 5
with reflex.model.session() as session:
with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore
session.exec(sqlmodel.select(AlembicSecond)).all()
assert errctx.match(r"no such table: alembicsecond")
# first table should still exist
result = session.exec(sqlmodel.select(AlembicThing)).all()
assert len(result) == 1
assert result[0].t2 == "bar"
sqlmodel.SQLModel.metadata.clear()
class AlembicThing(Model, table=True): # type: ignore
# changing column type not supported by default
t2: int = 42
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 5
# clear all metadata to avoid influencing subsequent tests
sqlmodel.SQLModel.metadata.clear()
# drop remaining tables
Model.automigrate()
assert len(list(versions.glob("*.py"))) == 6