rx.Model: automigrate using alembic (#1321)
This commit is contained in:
parent
fc87589434
commit
5505d10989
61
poetry.lock
generated
61
poetry.lock
generated
@ -1,5 +1,26 @@
|
||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "alembic"
|
||||
version = "1.11.1"
|
||||
description = "A database migration tool for SQLAlchemy."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "alembic-1.11.1-py3-none-any.whl", hash = "sha256:dc871798a601fab38332e38d6ddb38d5e734f60034baeb8e2db5b642fccd8ab8"},
|
||||
{file = "alembic-1.11.1.tar.gz", hash = "sha256:6a810a6b012c88b33458fceb869aef09ac75d6ace5291915ba7fae44de372c01"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
importlib-metadata = {version = "*", markers = "python_version < \"3.9\""}
|
||||
importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
|
||||
Mako = "*"
|
||||
SQLAlchemy = ">=1.3.0"
|
||||
typing-extensions = ">=4"
|
||||
|
||||
[package.extras]
|
||||
tz = ["python-dateutil"]
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "3.7.1"
|
||||
@ -501,6 +522,24 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
|
||||
perf = ["ipython"]
|
||||
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "importlib-resources"
|
||||
version = "5.12.0"
|
||||
description = "Read resources from Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "importlib_resources-5.12.0-py3-none-any.whl", hash = "sha256:7b1deeebbf351c7578e09bf2f63fa2ce8b5ffec296e0d349139d43cca061a81a"},
|
||||
{file = "importlib_resources-5.12.0.tar.gz", hash = "sha256:4be82589bf5c1d7999aedf2a45159d10cb3ca4f19b2271f8792bc8e6da7b22f6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
|
||||
testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.0.0"
|
||||
@ -529,6 +568,26 @@ MarkupSafe = ">=2.0"
|
||||
[package.extras]
|
||||
i18n = ["Babel (>=2.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "mako"
|
||||
version = "1.2.4"
|
||||
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "Mako-1.2.4-py3-none-any.whl", hash = "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818"},
|
||||
{file = "Mako-1.2.4.tar.gz", hash = "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
|
||||
MarkupSafe = ">=0.9.2"
|
||||
|
||||
[package.extras]
|
||||
babel = ["Babel"]
|
||||
lingua = ["lingua"]
|
||||
testing = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "2.2.0"
|
||||
@ -1854,4 +1913,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.7"
|
||||
content-hash = "b07623f193778651dd3ece81370d2d9a2ba3b53855a57baf0147e67bf07aaed8"
|
||||
content-hash = "cc659e46041316bc81ce1758334c6fa9ccf9812612ef67e170b173d6d1caa2b2"
|
||||
|
@ -62,6 +62,7 @@ pandas = [
|
||||
]
|
||||
asynctest = "^0.13.0"
|
||||
pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
|
||||
alembic = "^1.11.1"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
reflex = "reflex.reflex:main"
|
||||
|
@ -444,7 +444,7 @@ class App(Base):
|
||||
config = get_config()
|
||||
|
||||
# Update models during hot reload.
|
||||
if config.db_url is not None:
|
||||
if config.db_url is not None and not Model.automigrate():
|
||||
Model.create_all()
|
||||
|
||||
# Empty the .web pages directory
|
||||
|
@ -353,3 +353,6 @@ TOGGLE_COLOR_MODE = "toggleColorMode"
|
||||
# Server socket configuration variables
|
||||
CORS_ALLOWED_ORIGINS = get_value("CORS_ALLOWED_ORIGINS", ["*"], list)
|
||||
POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000
|
||||
|
||||
# Alembic migrations
|
||||
ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")
|
||||
|
168
reflex/model.py
168
reflex/model.py
@ -1,12 +1,30 @@
|
||||
"""Database built into Reflex."""
|
||||
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import sqlalchemy
|
||||
import sqlmodel
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.config import get_config
|
||||
|
||||
from . import constants
|
||||
|
||||
try:
|
||||
import alembic.autogenerate # pyright: ignore [reportMissingImports]
|
||||
import alembic.command # pyright: ignore [reportMissingImports]
|
||||
import alembic.operations.ops # pyright: ignore [reportMissingImports]
|
||||
import alembic.runtime.environment # pyright: ignore [reportMissingImports]
|
||||
import alembic.script # pyright: ignore [reportMissingImports]
|
||||
import alembic.util # pyright: ignore [reportMissingImports]
|
||||
from alembic.config import Config # pyright: ignore [reportMissingImports]
|
||||
|
||||
has_alembic = True
|
||||
except ImportError:
|
||||
has_alembic = False
|
||||
|
||||
|
||||
def get_engine(url: Optional[str] = None):
|
||||
"""Get the database engine.
|
||||
@ -75,6 +93,154 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
"""
|
||||
return get_engine()
|
||||
|
||||
@staticmethod
|
||||
def _alembic_config():
|
||||
"""Get the alembic configuration and script_directory.
|
||||
|
||||
Returns:
|
||||
tuple of (config, script_directory)
|
||||
"""
|
||||
config = Config(constants.ALEMBIC_CONFIG)
|
||||
return config, alembic.script.ScriptDirectory(
|
||||
config.get_main_option("script_location", default="version"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _alembic_render_item(
|
||||
type_: str,
|
||||
obj: Any,
|
||||
autogen_context: "alembic.autogenerate.api.AutogenContext",
|
||||
):
|
||||
"""Alembic render_item hook call.
|
||||
|
||||
This method is called to provide python code for the given obj,
|
||||
but currently it is only used to add `sqlmodel` to the import list
|
||||
when generating migration scripts.
|
||||
|
||||
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
|
||||
|
||||
Args:
|
||||
type_: one of "schema", "table", "column", "index",
|
||||
"unique_constraint", or "foreign_key_constraint"
|
||||
obj: the object being rendered
|
||||
autogen_context: shared AutogenContext passed to each render_item call
|
||||
|
||||
Returns:
|
||||
False - indicating that the default rendering should be used.
|
||||
"""
|
||||
autogen_context.imports.add("import sqlmodel")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
|
||||
"""Generate migration scripts for alembic-detectable changes.
|
||||
|
||||
Args:
|
||||
connection: sqlalchemy connection to use when detecting changes
|
||||
|
||||
Returns:
|
||||
True when changes have been detected.
|
||||
"""
|
||||
config, script_directory = cls._alembic_config()
|
||||
revision_context = alembic.autogenerate.api.RevisionContext(
|
||||
config=config,
|
||||
script_directory=script_directory,
|
||||
command_args=defaultdict(
|
||||
lambda: None,
|
||||
autogenerate=True,
|
||||
head="head",
|
||||
),
|
||||
)
|
||||
writer = alembic.autogenerate.rewriter.Rewriter()
|
||||
|
||||
@writer.rewrites(alembic.operations.ops.AddColumnOp)
|
||||
def render_add_column_with_server_default(context, revision, op):
|
||||
# Carry the sqlmodel default as server_default so that newly added
|
||||
# columns get the desired default value in existing rows
|
||||
if op.column.default is not None and op.column.server_default is None:
|
||||
op.column.server_default = sqlalchemy.DefaultClause(
|
||||
sqlalchemy.sql.expression.literal(op.column.default.arg),
|
||||
)
|
||||
return op
|
||||
|
||||
def run_autogenerate(rev, context):
|
||||
revision_context.run_autogenerate(rev, context)
|
||||
return []
|
||||
|
||||
with alembic.runtime.environment.EnvironmentContext(
|
||||
config=config,
|
||||
script=script_directory,
|
||||
fn=run_autogenerate,
|
||||
) as env:
|
||||
env.configure(
|
||||
connection=connection,
|
||||
target_metadata=sqlmodel.SQLModel.metadata,
|
||||
render_item=cls._alembic_render_item,
|
||||
process_revision_directives=writer, # type: ignore
|
||||
)
|
||||
env.run_migrations()
|
||||
changes_detected = False
|
||||
if revision_context.generated_revisions:
|
||||
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
|
||||
if upgrade_ops is not None:
|
||||
changes_detected = bool(upgrade_ops.ops)
|
||||
if changes_detected:
|
||||
for _script in revision_context.generate_scripts():
|
||||
pass # must iterate to actually generate the scripts
|
||||
return changes_detected
|
||||
|
||||
@classmethod
|
||||
def _alembic_upgrade(
|
||||
cls,
|
||||
connection: sqlalchemy.engine.Connection,
|
||||
to_rev: str = "head",
|
||||
) -> None:
|
||||
"""Apply alembic migrations up to the given revision.
|
||||
|
||||
Args:
|
||||
connection: sqlalchemy connection to use when performing upgrade
|
||||
to_rev: revision to migrate towards
|
||||
"""
|
||||
config, script_directory = cls._alembic_config()
|
||||
|
||||
def run_upgrade(rev, context):
|
||||
return script_directory._upgrade_revs(to_rev, rev)
|
||||
|
||||
# apply updates to database
|
||||
with alembic.runtime.environment.EnvironmentContext(
|
||||
config=config,
|
||||
script=script_directory,
|
||||
fn=run_upgrade,
|
||||
) as env:
|
||||
env.configure(connection=connection)
|
||||
env.run_migrations()
|
||||
|
||||
@classmethod
|
||||
def automigrate(cls) -> Optional[bool]:
|
||||
"""Generate and execute migrations for all sqlmodel Model classes.
|
||||
|
||||
If alembic is not installed or has not been initialized for the project,
|
||||
then no action is performed.
|
||||
|
||||
If models in the app have changed in incompatible ways that alembic
|
||||
cannot automatically generate revisions for, the app may not be able to
|
||||
start up until migration scripts have been corrected by hand.
|
||||
|
||||
Returns:
|
||||
True - indicating the process was successful
|
||||
None - indicating the process was skipped
|
||||
"""
|
||||
if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
return
|
||||
|
||||
with cls.get_db_engine().connect() as connection:
|
||||
cls._alembic_upgrade(connection=connection)
|
||||
changes_detected = cls._alembic_autogenerate(connection=connection)
|
||||
if changes_detected:
|
||||
cls._alembic_upgrade(connection=connection)
|
||||
connection.commit()
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def select(cls):
|
||||
|
@ -1,5 +1,8 @@
|
||||
"""Test fixtures."""
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, List
|
||||
|
||||
import pytest
|
||||
@ -479,3 +482,48 @@ def router_data(router_data_headers) -> Dict[str, str]:
|
||||
"headers": router_data_headers,
|
||||
"ip": "127.0.0.1",
|
||||
}
|
||||
|
||||
|
||||
# borrowed from py3.11
|
||||
class chdir(contextlib.AbstractContextManager):
|
||||
"""Non thread-safe context manager to change the current working directory."""
|
||||
|
||||
def __init__(self, path):
|
||||
"""Prepare contextmanager.
|
||||
|
||||
Args:
|
||||
path: the path to change to
|
||||
"""
|
||||
self.path = path
|
||||
self._old_cwd = []
|
||||
|
||||
def __enter__(self):
|
||||
"""Save current directory and perform chdir."""
|
||||
self._old_cwd.append(Path(".").resolve())
|
||||
os.chdir(self.path)
|
||||
|
||||
def __exit__(self, *excinfo):
|
||||
"""Change back to previous directory on stack.
|
||||
|
||||
Args:
|
||||
excinfo: sys.exc_info captured in the context block
|
||||
"""
|
||||
os.chdir(self._old_cwd.pop())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_working_dir(tmp_path):
|
||||
"""Create a temporary directory and chdir to it.
|
||||
|
||||
After the test executes, chdir back to the original working directory.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest tmp_path fixture creates per-test temp dir
|
||||
|
||||
Yields:
|
||||
subdirectory of tmp_path which is now the current working directory.
|
||||
"""
|
||||
working_dir = tmp_path / "working_dir"
|
||||
working_dir.mkdir()
|
||||
with chdir(working_dir):
|
||||
yield working_dir
|
||||
|
@ -1,6 +1,13 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
import sqlmodel
|
||||
|
||||
import reflex.constants
|
||||
import reflex.model
|
||||
from reflex.model import Model
|
||||
|
||||
|
||||
@ -49,3 +56,122 @@ def test_custom_primary_key(model_custom_primary):
|
||||
model_custom_primary: Fixture.
|
||||
"""
|
||||
assert "id" not in model_custom_primary.__class__.__fields__
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings(
|
||||
"ignore:This declarative base already contains a class with the same class name",
|
||||
)
|
||||
def test_automigration(tmp_working_dir, monkeypatch):
|
||||
"""Test alembic automigration with add and drop table and column.
|
||||
|
||||
Args:
|
||||
tmp_working_dir: directory where database and migrations are stored
|
||||
monkeypatch: pytest fixture to overwrite attributes
|
||||
"""
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "alembic", "init", "alembic"],
|
||||
cwd=tmp_working_dir,
|
||||
)
|
||||
alembic_ini = tmp_working_dir / "alembic.ini"
|
||||
versions = tmp_working_dir / "alembic" / "versions"
|
||||
assert alembic_ini.exists()
|
||||
assert versions.exists()
|
||||
|
||||
config_mock = mock.Mock()
|
||||
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
|
||||
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
|
||||
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
|
||||
|
||||
# initial table
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t1: str
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 1
|
||||
|
||||
with reflex.model.session() as session:
|
||||
session.add(AlembicThing(id=None, t1="foo"))
|
||||
session.commit()
|
||||
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
|
||||
# Create column t2
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t1: str
|
||||
t2: str = "bar"
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 2
|
||||
|
||||
with reflex.model.session() as session:
|
||||
result = session.exec(sqlmodel.select(AlembicThing)).all()
|
||||
assert len(result) == 1
|
||||
assert result[0].t2 == "bar"
|
||||
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
|
||||
# Drop column t1
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t2: str = "bar"
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 3
|
||||
|
||||
with reflex.model.session() as session:
|
||||
result = session.exec(sqlmodel.select(AlembicThing)).all()
|
||||
assert len(result) == 1
|
||||
assert result[0].t2 == "bar"
|
||||
|
||||
# Add table
|
||||
class AlembicSecond(Model, table=True): # type: ignore
|
||||
a: int = 42
|
||||
b: float = 4.2
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
with reflex.model.session() as session:
|
||||
session.add(AlembicSecond(id=None))
|
||||
session.commit()
|
||||
result = session.exec(sqlmodel.select(AlembicSecond)).all()
|
||||
assert len(result) == 1
|
||||
assert result[0].a == 42
|
||||
assert result[0].b == 4.2
|
||||
|
||||
# No-op
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
# drop table (AlembicSecond)
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t2: str = "bar"
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
with reflex.model.session() as session:
|
||||
with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore
|
||||
session.exec(sqlmodel.select(AlembicSecond)).all()
|
||||
assert errctx.match(r"no such table: alembicsecond")
|
||||
# first table should still exist
|
||||
result = session.exec(sqlmodel.select(AlembicThing)).all()
|
||||
assert len(result) == 1
|
||||
assert result[0].t2 == "bar"
|
||||
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
# changing column type not supported by default
|
||||
t2: int = 42
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
# clear all metadata to avoid influencing subsequent tests
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
|
||||
# drop remaining tables
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 6
|
||||
|
Loading…
Reference in New Issue
Block a user