diff --git a/poetry.lock b/poetry.lock index 4a31b5010..a5e3a2a3a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,26 @@ # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +[[package]] +name = "alembic" +version = "1.11.1" +description = "A database migration tool for SQLAlchemy." +optional = false +python-versions = ">=3.7" +files = [ + {file = "alembic-1.11.1-py3-none-any.whl", hash = "sha256:dc871798a601fab38332e38d6ddb38d5e734f60034baeb8e2db5b642fccd8ab8"}, + {file = "alembic-1.11.1.tar.gz", hash = "sha256:6a810a6b012c88b33458fceb869aef09ac75d6ace5291915ba7fae44de372c01"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.9\""} +importlib-resources = {version = "*", markers = "python_version < \"3.9\""} +Mako = "*" +SQLAlchemy = ">=1.3.0" +typing-extensions = ">=4" + +[package.extras] +tz = ["python-dateutil"] + [[package]] name = "anyio" version = "3.7.1" @@ -501,6 +522,24 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +[[package]] +name = "importlib-resources" +version = "5.12.0" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "importlib_resources-5.12.0-py3-none-any.whl", hash = "sha256:7b1deeebbf351c7578e09bf2f63fa2ce8b5ffec296e0d349139d43cca061a81a"}, + {file = "importlib_resources-5.12.0.tar.gz", hash = "sha256:4be82589bf5c1d7999aedf2a45159d10cb3ca4f19b2271f8792bc8e6da7b22f6"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -529,6 +568,26 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "mako" +version = "1.2.4" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." +optional = false +python-versions = ">=3.7" +files = [ + {file = "Mako-1.2.4-py3-none-any.whl", hash = "sha256:c97c79c018b9165ac9922ae4f32da095ffd3c4e6872b45eded42926deea46818"}, + {file = "Mako-1.2.4.tar.gz", hash = "sha256:d60a3903dc3bb01a18ad6a89cdbe2e4eadc69c0bc8ef1e3773ba53d44c3f7a34"}, +] + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + [[package]] name = "markdown-it-py" version = "2.2.0" @@ -1854,4 +1913,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "b07623f193778651dd3ece81370d2d9a2ba3b53855a57baf0147e67bf07aaed8" +content-hash = "cc659e46041316bc81ce1758334c6fa9ccf9812612ef67e170b173d6d1caa2b2" diff --git a/pyproject.toml b/pyproject.toml index 5cd919492..1ac4d8780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ pandas = [ ] asynctest = "^0.13.0" pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"} +alembic = "^1.11.1" [tool.poetry.scripts] reflex = "reflex.reflex:main" diff --git a/reflex/app.py b/reflex/app.py index 93b5d4d80..7ffb48a97 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -444,7 +444,7 @@ class App(Base): config = get_config() # Update models during hot reload. - if config.db_url is not None: + if config.db_url is not None and not Model.automigrate(): Model.create_all() # Empty the .web pages directory diff --git a/reflex/constants.py b/reflex/constants.py index 0dc406cad..2ba4adeaa 100644 --- a/reflex/constants.py +++ b/reflex/constants.py @@ -353,3 +353,6 @@ TOGGLE_COLOR_MODE = "toggleColorMode" # Server socket configuration variables CORS_ALLOWED_ORIGINS = get_value("CORS_ALLOWED_ORIGINS", ["*"], list) POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000 + +# Alembic migrations +ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini") diff --git a/reflex/model.py b/reflex/model.py index 51135b64f..876eee0e5 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -1,12 +1,30 @@ """Database built into Reflex.""" -from typing import Optional +from collections import defaultdict +from pathlib import Path +from typing import Any, Optional +import sqlalchemy import sqlmodel from reflex.base import Base from reflex.config import get_config +from . import constants + +try: + import alembic.autogenerate # pyright: ignore [reportMissingImports] + import alembic.command # pyright: ignore [reportMissingImports] + import alembic.operations.ops # pyright: ignore [reportMissingImports] + import alembic.runtime.environment # pyright: ignore [reportMissingImports] + import alembic.script # pyright: ignore [reportMissingImports] + import alembic.util # pyright: ignore [reportMissingImports] + from alembic.config import Config # pyright: ignore [reportMissingImports] + + has_alembic = True +except ImportError: + has_alembic = False + def get_engine(url: Optional[str] = None): """Get the database engine. @@ -75,6 +93,154 @@ class Model(Base, sqlmodel.SQLModel): """ return get_engine() + @staticmethod + def _alembic_config(): + """Get the alembic configuration and script_directory. + + Returns: + tuple of (config, script_directory) + """ + config = Config(constants.ALEMBIC_CONFIG) + return config, alembic.script.ScriptDirectory( + config.get_main_option("script_location", default="version"), + ) + + @staticmethod + def _alembic_render_item( + type_: str, + obj: Any, + autogen_context: "alembic.autogenerate.api.AutogenContext", + ): + """Alembic render_item hook call. + + This method is called to provide python code for the given obj, + but currently it is only used to add `sqlmodel` to the import list + when generating migration scripts. + + See https://alembic.sqlalchemy.org/en/latest/api/runtime.html + + Args: + type_: one of "schema", "table", "column", "index", + "unique_constraint", or "foreign_key_constraint" + obj: the object being rendered + autogen_context: shared AutogenContext passed to each render_item call + + Returns: + False - indicating that the default rendering should be used. + """ + autogen_context.imports.add("import sqlmodel") + return False + + @classmethod + def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool: + """Generate migration scripts for alembic-detectable changes. + + Args: + connection: sqlalchemy connection to use when detecting changes + + Returns: + True when changes have been detected. + """ + config, script_directory = cls._alembic_config() + revision_context = alembic.autogenerate.api.RevisionContext( + config=config, + script_directory=script_directory, + command_args=defaultdict( + lambda: None, + autogenerate=True, + head="head", + ), + ) + writer = alembic.autogenerate.rewriter.Rewriter() + + @writer.rewrites(alembic.operations.ops.AddColumnOp) + def render_add_column_with_server_default(context, revision, op): + # Carry the sqlmodel default as server_default so that newly added + # columns get the desired default value in existing rows + if op.column.default is not None and op.column.server_default is None: + op.column.server_default = sqlalchemy.DefaultClause( + sqlalchemy.sql.expression.literal(op.column.default.arg), + ) + return op + + def run_autogenerate(rev, context): + revision_context.run_autogenerate(rev, context) + return [] + + with alembic.runtime.environment.EnvironmentContext( + config=config, + script=script_directory, + fn=run_autogenerate, + ) as env: + env.configure( + connection=connection, + target_metadata=sqlmodel.SQLModel.metadata, + render_item=cls._alembic_render_item, + process_revision_directives=writer, # type: ignore + ) + env.run_migrations() + changes_detected = False + if revision_context.generated_revisions: + upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops + if upgrade_ops is not None: + changes_detected = bool(upgrade_ops.ops) + if changes_detected: + for _script in revision_context.generate_scripts(): + pass # must iterate to actually generate the scripts + return changes_detected + + @classmethod + def _alembic_upgrade( + cls, + connection: sqlalchemy.engine.Connection, + to_rev: str = "head", + ) -> None: + """Apply alembic migrations up to the given revision. + + Args: + connection: sqlalchemy connection to use when performing upgrade + to_rev: revision to migrate towards + """ + config, script_directory = cls._alembic_config() + + def run_upgrade(rev, context): + return script_directory._upgrade_revs(to_rev, rev) + + # apply updates to database + with alembic.runtime.environment.EnvironmentContext( + config=config, + script=script_directory, + fn=run_upgrade, + ) as env: + env.configure(connection=connection) + env.run_migrations() + + @classmethod + def automigrate(cls) -> Optional[bool]: + """Generate and execute migrations for all sqlmodel Model classes. + + If alembic is not installed or has not been initialized for the project, + then no action is performed. + + If models in the app have changed in incompatible ways that alembic + cannot automatically generate revisions for, the app may not be able to + start up until migration scripts have been corrected by hand. + + Returns: + True - indicating the process was successful + None - indicating the process was skipped + """ + if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists(): + return + + with cls.get_db_engine().connect() as connection: + cls._alembic_upgrade(connection=connection) + changes_detected = cls._alembic_autogenerate(connection=connection) + if changes_detected: + cls._alembic_upgrade(connection=connection) + connection.commit() + return True + @classmethod @property def select(cls): diff --git a/tests/conftest.py b/tests/conftest.py index 6c18cbe48..6ba10f2ee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,8 @@ """Test fixtures.""" +import contextlib +import os import platform +from pathlib import Path from typing import Dict, Generator, List import pytest @@ -479,3 +482,48 @@ def router_data(router_data_headers) -> Dict[str, str]: "headers": router_data_headers, "ip": "127.0.0.1", } + + +# borrowed from py3.11 +class chdir(contextlib.AbstractContextManager): + """Non thread-safe context manager to change the current working directory.""" + + def __init__(self, path): + """Prepare contextmanager. + + Args: + path: the path to change to + """ + self.path = path + self._old_cwd = [] + + def __enter__(self): + """Save current directory and perform chdir.""" + self._old_cwd.append(Path(".").resolve()) + os.chdir(self.path) + + def __exit__(self, *excinfo): + """Change back to previous directory on stack. + + Args: + excinfo: sys.exc_info captured in the context block + """ + os.chdir(self._old_cwd.pop()) + + +@pytest.fixture +def tmp_working_dir(tmp_path): + """Create a temporary directory and chdir to it. + + After the test executes, chdir back to the original working directory. + + Args: + tmp_path: pytest tmp_path fixture creates per-test temp dir + + Yields: + subdirectory of tmp_path which is now the current working directory. + """ + working_dir = tmp_path / "working_dir" + working_dir.mkdir() + with chdir(working_dir): + yield working_dir diff --git a/tests/test_model.py b/tests/test_model.py index 44ee86a70..343b72b1b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,6 +1,13 @@ +import subprocess +import sys +from unittest import mock + import pytest +import sqlalchemy import sqlmodel +import reflex.constants +import reflex.model from reflex.model import Model @@ -49,3 +56,122 @@ def test_custom_primary_key(model_custom_primary): model_custom_primary: Fixture. """ assert "id" not in model_custom_primary.__class__.__fields__ + + +@pytest.mark.filterwarnings( + "ignore:This declarative base already contains a class with the same class name", +) +def test_automigration(tmp_working_dir, monkeypatch): + """Test alembic automigration with add and drop table and column. + + Args: + tmp_working_dir: directory where database and migrations are stored + monkeypatch: pytest fixture to overwrite attributes + """ + subprocess.run( + [sys.executable, "-m", "alembic", "init", "alembic"], + cwd=tmp_working_dir, + ) + alembic_ini = tmp_working_dir / "alembic.ini" + versions = tmp_working_dir / "alembic" / "versions" + assert alembic_ini.exists() + assert versions.exists() + + config_mock = mock.Mock() + config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db" + monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock)) + monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini)) + + # initial table + class AlembicThing(Model, table=True): # type: ignore + t1: str + + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 1 + + with reflex.model.session() as session: + session.add(AlembicThing(id=None, t1="foo")) + session.commit() + + sqlmodel.SQLModel.metadata.clear() + + # Create column t2 + class AlembicThing(Model, table=True): # type: ignore + t1: str + t2: str = "bar" + + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 2 + + with reflex.model.session() as session: + result = session.exec(sqlmodel.select(AlembicThing)).all() + assert len(result) == 1 + assert result[0].t2 == "bar" + + sqlmodel.SQLModel.metadata.clear() + + # Drop column t1 + class AlembicThing(Model, table=True): # type: ignore + t2: str = "bar" + + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 3 + + with reflex.model.session() as session: + result = session.exec(sqlmodel.select(AlembicThing)).all() + assert len(result) == 1 + assert result[0].t2 == "bar" + + # Add table + class AlembicSecond(Model, table=True): # type: ignore + a: int = 42 + b: float = 4.2 + + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 4 + + with reflex.model.session() as session: + session.add(AlembicSecond(id=None)) + session.commit() + result = session.exec(sqlmodel.select(AlembicSecond)).all() + assert len(result) == 1 + assert result[0].a == 42 + assert result[0].b == 4.2 + + # No-op + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 4 + + # drop table (AlembicSecond) + sqlmodel.SQLModel.metadata.clear() + + class AlembicThing(Model, table=True): # type: ignore + t2: str = "bar" + + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 5 + + with reflex.model.session() as session: + with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore + session.exec(sqlmodel.select(AlembicSecond)).all() + assert errctx.match(r"no such table: alembicsecond") + # first table should still exist + result = session.exec(sqlmodel.select(AlembicThing)).all() + assert len(result) == 1 + assert result[0].t2 == "bar" + + sqlmodel.SQLModel.metadata.clear() + + class AlembicThing(Model, table=True): # type: ignore + # changing column type not supported by default + t2: int = 42 + + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 5 + + # clear all metadata to avoid influencing subsequent tests + sqlmodel.SQLModel.metadata.clear() + + # drop remaining tables + Model.automigrate() + assert len(list(versions.glob("*.py"))) == 6