reflex db migrate CLI and associated config (#1336)
This commit is contained in:
parent
391135e235
commit
4a661a5395
2
poetry.lock
generated
2
poetry.lock
generated
@ -2128,4 +2128,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.7"
|
||||
content-hash = "f4586d218b5320f0b595db276e8426db6bfffb406e8108b8a1bd9e785b6407c4"
|
||||
content-hash = "ac27016107e8a033aa39d9a712d3ef685132e22ede599a26214b17da6ff35829"
|
||||
|
@ -45,6 +45,7 @@ websockets = "^10.4"
|
||||
starlette-admin = "^0.9.0"
|
||||
python-dotenv = "^0.13.0"
|
||||
importlib-metadata = {version = "^6.7.0", python = ">=3.7, <3.8"}
|
||||
alembic = "^1.11.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.1.2"
|
||||
@ -62,7 +63,6 @@ pandas = [
|
||||
]
|
||||
asynctest = "^0.13.0"
|
||||
pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
|
||||
alembic = "^1.11.1"
|
||||
selenium = "^4.10.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
@ -452,10 +452,6 @@ class App(Base):
|
||||
# Get the env mode.
|
||||
config = get_config()
|
||||
|
||||
# Update models during hot reload.
|
||||
if config.db_url is not None and not Model.automigrate():
|
||||
Model.create_all()
|
||||
|
||||
# Empty the .web pages directory
|
||||
compiler.purge_web_pages_dir()
|
||||
|
||||
|
@ -123,7 +123,10 @@ def compile_state(state: Type[State]) -> Dict:
|
||||
Returns:
|
||||
A dictionary of the compiled state.
|
||||
"""
|
||||
initial_state = state().dict()
|
||||
try:
|
||||
initial_state = state().dict()
|
||||
except Exception:
|
||||
initial_state = state().dict(include_computed=False)
|
||||
initial_state.update(
|
||||
{
|
||||
"events": [{"name": get_hydrate_event(state)}],
|
||||
|
100
reflex/model.py
100
reflex/model.py
@ -4,26 +4,20 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import alembic.autogenerate
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
import alembic.operations.ops
|
||||
import alembic.runtime.environment
|
||||
import alembic.script
|
||||
import alembic.util
|
||||
import sqlalchemy
|
||||
import sqlmodel
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.config import get_config
|
||||
|
||||
from . import constants
|
||||
|
||||
try:
|
||||
import alembic.autogenerate # pyright: ignore [reportMissingImports]
|
||||
import alembic.command # pyright: ignore [reportMissingImports]
|
||||
import alembic.operations.ops # pyright: ignore [reportMissingImports]
|
||||
import alembic.runtime.environment # pyright: ignore [reportMissingImports]
|
||||
import alembic.script # pyright: ignore [reportMissingImports]
|
||||
import alembic.util # pyright: ignore [reportMissingImports]
|
||||
from alembic.config import Config # pyright: ignore [reportMissingImports]
|
||||
|
||||
has_alembic = True
|
||||
except ImportError:
|
||||
has_alembic = False
|
||||
from reflex.utils import console
|
||||
|
||||
|
||||
def get_engine(url: Optional[str] = None):
|
||||
@ -42,6 +36,10 @@ def get_engine(url: Optional[str] = None):
|
||||
url = url or conf.db_url
|
||||
if url is None:
|
||||
raise ValueError("No database url configured")
|
||||
if not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
console.print(
|
||||
"[red]Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||
)
|
||||
return sqlmodel.create_engine(
|
||||
url,
|
||||
echo=False,
|
||||
@ -100,7 +98,7 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
Returns:
|
||||
tuple of (config, script_directory)
|
||||
"""
|
||||
config = Config(constants.ALEMBIC_CONFIG)
|
||||
config = alembic.config.Config(constants.ALEMBIC_CONFIG)
|
||||
return config, alembic.script.ScriptDirectory(
|
||||
config.get_main_option("script_location", default="version"),
|
||||
)
|
||||
@ -120,27 +118,45 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
|
||||
|
||||
Args:
|
||||
type_: one of "schema", "table", "column", "index",
|
||||
"unique_constraint", or "foreign_key_constraint"
|
||||
obj: the object being rendered
|
||||
autogen_context: shared AutogenContext passed to each render_item call
|
||||
type_: One of "schema", "table", "column", "index",
|
||||
"unique_constraint", or "foreign_key_constraint".
|
||||
obj: The object being rendered.
|
||||
autogen_context: Shared AutogenContext passed to each render_item call.
|
||||
|
||||
Returns:
|
||||
False - indicating that the default rendering should be used.
|
||||
False - Indicating that the default rendering should be used.
|
||||
"""
|
||||
autogen_context.imports.add("import sqlmodel")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
|
||||
def alembic_init(cls):
|
||||
"""Initialize alembic for the project."""
|
||||
alembic.command.init(
|
||||
config=alembic.config.Config(constants.ALEMBIC_CONFIG),
|
||||
directory=str(Path(constants.ALEMBIC_CONFIG).parent / "alembic"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def alembic_autogenerate(
|
||||
cls,
|
||||
connection: sqlalchemy.engine.Connection,
|
||||
message: Optional[str] = None,
|
||||
write_migration_scripts: bool = True,
|
||||
) -> bool:
|
||||
"""Generate migration scripts for alembic-detectable changes.
|
||||
|
||||
Args:
|
||||
connection: sqlalchemy connection to use when detecting changes
|
||||
connection: SQLAlchemy connection to use when detecting changes.
|
||||
message: Human readable identifier describing the generated revision.
|
||||
write_migration_scripts: If True, write autogenerated revisions to script directory.
|
||||
|
||||
Returns:
|
||||
True when changes have been detected.
|
||||
"""
|
||||
if not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
return False
|
||||
|
||||
config, script_directory = cls._alembic_config()
|
||||
revision_context = alembic.autogenerate.api.RevisionContext(
|
||||
config=config,
|
||||
@ -149,6 +165,7 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
lambda: None,
|
||||
autogenerate=True,
|
||||
head="head",
|
||||
message=message,
|
||||
),
|
||||
)
|
||||
writer = alembic.autogenerate.rewriter.Rewriter()
|
||||
@ -156,7 +173,7 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
@writer.rewrites(alembic.operations.ops.AddColumnOp)
|
||||
def render_add_column_with_server_default(context, revision, op):
|
||||
# Carry the sqlmodel default as server_default so that newly added
|
||||
# columns get the desired default value in existing rows
|
||||
# columns get the desired default value in existing rows.
|
||||
if op.column.default is not None and op.column.server_default is None:
|
||||
op.column.server_default = sqlalchemy.DefaultClause(
|
||||
sqlalchemy.sql.expression.literal(op.column.default.arg),
|
||||
@ -184,9 +201,9 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
|
||||
if upgrade_ops is not None:
|
||||
changes_detected = bool(upgrade_ops.ops)
|
||||
if changes_detected:
|
||||
for _script in revision_context.generate_scripts():
|
||||
pass # must iterate to actually generate the scripts
|
||||
if changes_detected and write_migration_scripts:
|
||||
# Must iterate the generator to actually write the scripts.
|
||||
_ = tuple(revision_context.generate_scripts())
|
||||
return changes_detected
|
||||
|
||||
@classmethod
|
||||
@ -198,15 +215,14 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
"""Apply alembic migrations up to the given revision.
|
||||
|
||||
Args:
|
||||
connection: sqlalchemy connection to use when performing upgrade
|
||||
to_rev: revision to migrate towards
|
||||
connection: SQLAlchemy connection to use when performing upgrade.
|
||||
to_rev: Revision to migrate towards.
|
||||
"""
|
||||
config, script_directory = cls._alembic_config()
|
||||
|
||||
def run_upgrade(rev, context):
|
||||
return script_directory._upgrade_revs(to_rev, rev)
|
||||
|
||||
# apply updates to database
|
||||
with alembic.runtime.environment.EnvironmentContext(
|
||||
config=config,
|
||||
script=script_directory,
|
||||
@ -216,28 +232,36 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
env.run_migrations()
|
||||
|
||||
@classmethod
|
||||
def automigrate(cls) -> Optional[bool]:
|
||||
"""Generate and execute migrations for all sqlmodel Model classes.
|
||||
def migrate(cls, autogenerate: bool = False) -> Optional[bool]:
|
||||
"""Execute alembic migrations for all sqlmodel Model classes.
|
||||
|
||||
If alembic is not installed or has not been initialized for the project,
|
||||
then no action is performed.
|
||||
|
||||
If there are no revisions currently tracked by alembic, then
|
||||
an initial revision will be created based on sqlmodel metadata.
|
||||
|
||||
If models in the app have changed in incompatible ways that alembic
|
||||
cannot automatically generate revisions for, the app may not be able to
|
||||
start up until migration scripts have been corrected by hand.
|
||||
|
||||
Args:
|
||||
autogenerate: If True, generate migration script and use it to upgrade schema
|
||||
(otherwise, just bring the schema to current "head" revision).
|
||||
|
||||
Returns:
|
||||
True - indicating the process was successful
|
||||
None - indicating the process was skipped
|
||||
True - indicating the process was successful.
|
||||
None - indicating the process was skipped.
|
||||
"""
|
||||
if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
if not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
return
|
||||
|
||||
with cls.get_db_engine().connect() as connection:
|
||||
cls._alembic_upgrade(connection=connection)
|
||||
changes_detected = cls._alembic_autogenerate(connection=connection)
|
||||
if changes_detected:
|
||||
cls._alembic_upgrade(connection=connection)
|
||||
if autogenerate:
|
||||
changes_detected = cls.alembic_autogenerate(connection=connection)
|
||||
if changes_detected:
|
||||
cls._alembic_upgrade(connection=connection)
|
||||
connection.commit()
|
||||
return True
|
||||
|
||||
|
@ -9,7 +9,7 @@ from pathlib import Path
|
||||
import httpx
|
||||
import typer
|
||||
|
||||
from reflex import constants
|
||||
from reflex import constants, model
|
||||
from reflex.config import get_config
|
||||
from reflex.utils import build, console, exec, prerequisites, processes, telemetry
|
||||
|
||||
@ -132,6 +132,9 @@ def run(
|
||||
# Check the admin dashboard settings.
|
||||
prerequisites.check_admin_settings()
|
||||
|
||||
# Warn if schema is not up to date.
|
||||
prerequisites.check_schema_up_to_date()
|
||||
|
||||
# Get the frontend and backend commands, based on the environment.
|
||||
setup_frontend = frontend_cmd = backend_cmd = None
|
||||
if env == constants.Env.DEV:
|
||||
@ -158,7 +161,6 @@ def run(
|
||||
target=frontend_cmd, args=(Path.cwd(), frontend_port, loglevel)
|
||||
).start()
|
||||
if backend:
|
||||
build.setup_backend()
|
||||
threading.Thread(
|
||||
target=backend_cmd,
|
||||
args=(app.__name__, backend_host, backend_port, loglevel),
|
||||
@ -258,6 +260,51 @@ def export(
|
||||
)
|
||||
|
||||
|
||||
db_cli = typer.Typer()
|
||||
|
||||
|
||||
@db_cli.command(name="init")
|
||||
def db_init():
|
||||
"""Create database schema and migration configuration."""
|
||||
if get_config().db_url is None:
|
||||
console.print("[red]db_url is not configured, cannot initialize.")
|
||||
if Path(constants.ALEMBIC_CONFIG).exists():
|
||||
console.print(
|
||||
"[red]Database is already initialized. Use "
|
||||
"[bold]reflex db makemigrations[/bold] to create schema change "
|
||||
"scripts and [bold]reflex db migrate[/bold] to apply migrations "
|
||||
"to a new or existing database.",
|
||||
)
|
||||
prerequisites.get_app()
|
||||
model.Model.alembic_init()
|
||||
model.Model.migrate(autogenerate=True)
|
||||
|
||||
|
||||
@db_cli.command()
|
||||
def migrate():
|
||||
"""Create or update database schema based on app models or existing migration scripts."""
|
||||
prerequisites.get_app()
|
||||
if not prerequisites.check_db_initialized():
|
||||
return
|
||||
model.Model.migrate()
|
||||
prerequisites.check_schema_up_to_date()
|
||||
|
||||
|
||||
@db_cli.command()
|
||||
def makemigrations(
|
||||
message: str = typer.Option(
|
||||
None, help="Human readable identifier for the generated revision."
|
||||
),
|
||||
):
|
||||
"""Create autogenerated alembic migration scripts."""
|
||||
prerequisites.get_app()
|
||||
if not prerequisites.check_db_initialized():
|
||||
return
|
||||
with model.Model.get_db_engine().connect() as connection:
|
||||
model.Model.alembic_autogenerate(connection=connection, message=message)
|
||||
|
||||
|
||||
cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema")
|
||||
main = cli
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -12,7 +12,6 @@ from typing import Optional, Union
|
||||
from rich.progress import Progress
|
||||
|
||||
from reflex import constants
|
||||
from reflex.config import get_config
|
||||
from reflex.utils import path_ops, prerequisites
|
||||
from reflex.utils.processes import new_process
|
||||
|
||||
@ -240,16 +239,3 @@ def setup_frontend_prod(
|
||||
"""
|
||||
setup_frontend(root, loglevel, disable_telemetry)
|
||||
export_app(loglevel=loglevel)
|
||||
|
||||
|
||||
def setup_backend():
|
||||
"""Set up backend.
|
||||
|
||||
Specifically ensures backend database is updated when running --no-frontend.
|
||||
"""
|
||||
# Import here to avoid circular imports.
|
||||
from reflex.model import Model
|
||||
|
||||
config = get_config()
|
||||
if config.db_url is not None:
|
||||
Model.create_all()
|
||||
|
@ -16,10 +16,11 @@ from types import ModuleType
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from alembic.util.exc import CommandError
|
||||
from packaging import version
|
||||
from redis import Redis
|
||||
|
||||
from reflex import constants
|
||||
from reflex import constants, model
|
||||
from reflex.config import get_config
|
||||
from reflex.utils import console, path_ops
|
||||
|
||||
@ -370,6 +371,41 @@ def check_admin_settings():
|
||||
)
|
||||
|
||||
|
||||
def check_db_initialized() -> bool:
|
||||
"""Check if the database migrations are initialized.
|
||||
|
||||
Returns:
|
||||
True if alembic is initialized (or if database is not used).
|
||||
"""
|
||||
if get_config().db_url is not None and not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
console.print(
|
||||
"[red]Database is not initialized. Run [bold]reflex db init[/bold] first."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_schema_up_to_date():
|
||||
"""Check if the sqlmodel metadata matches the current database schema."""
|
||||
if get_config().db_url is None or not Path(constants.ALEMBIC_CONFIG).exists():
|
||||
return
|
||||
with model.Model.get_db_engine().connect() as connection:
|
||||
try:
|
||||
if model.Model.alembic_autogenerate(
|
||||
connection=connection,
|
||||
write_migration_scripts=False,
|
||||
):
|
||||
console.print(
|
||||
"[red]Detected database schema changes. Run [bold]reflex db makemigrations[/bold] "
|
||||
"to generate migration scripts.",
|
||||
)
|
||||
except CommandError as command_error:
|
||||
if "Target database is not up to date." in str(command_error):
|
||||
console.print(
|
||||
f"[red]{command_error} Run [bold]reflex db migrate[/bold] to update database."
|
||||
)
|
||||
|
||||
|
||||
def migrate_to_reflex():
|
||||
"""Migration from Pynecone to Reflex."""
|
||||
# Check if the old config file exists.
|
||||
|
@ -1,5 +1,3 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -68,26 +66,28 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
tmp_working_dir: directory where database and migrations are stored
|
||||
monkeypatch: pytest fixture to overwrite attributes
|
||||
"""
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "alembic", "init", "alembic"],
|
||||
cwd=tmp_working_dir,
|
||||
)
|
||||
alembic_ini = tmp_working_dir / "alembic.ini"
|
||||
versions = tmp_working_dir / "alembic" / "versions"
|
||||
assert alembic_ini.exists()
|
||||
assert versions.exists()
|
||||
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
|
||||
|
||||
config_mock = mock.Mock()
|
||||
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
|
||||
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
|
||||
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
|
||||
|
||||
Model.alembic_init()
|
||||
assert alembic_ini.exists()
|
||||
assert versions.exists()
|
||||
|
||||
# initial table
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t1: str
|
||||
|
||||
Model.automigrate()
|
||||
assert len(list(versions.glob("*.py"))) == 1
|
||||
with Model.get_db_engine().connect() as connection:
|
||||
Model.alembic_autogenerate(connection=connection, message="Initial Revision")
|
||||
Model.migrate()
|
||||
version_scripts = list(versions.glob("*.py"))
|
||||
assert len(version_scripts) == 1
|
||||
assert version_scripts[0].name.endswith("initial_revision.py")
|
||||
|
||||
with reflex.model.session() as session:
|
||||
session.add(AlembicThing(id=None, t1="foo"))
|
||||
@ -100,7 +100,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
t1: str
|
||||
t2: str = "bar"
|
||||
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 2
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -114,7 +114,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t2: str = "bar"
|
||||
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 3
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -127,7 +127,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
a: int = 42
|
||||
b: float = 4.2
|
||||
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -139,7 +139,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
assert result[0].b == 4.2
|
||||
|
||||
# No-op
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 4
|
||||
|
||||
# drop table (AlembicSecond)
|
||||
@ -148,7 +148,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
class AlembicThing(Model, table=True): # type: ignore
|
||||
t2: str = "bar"
|
||||
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
with reflex.model.session() as session:
|
||||
@ -166,12 +166,12 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
||||
# changing column type not supported by default
|
||||
t2: int = 42
|
||||
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 5
|
||||
|
||||
# clear all metadata to avoid influencing subsequent tests
|
||||
sqlmodel.SQLModel.metadata.clear()
|
||||
|
||||
# drop remaining tables
|
||||
Model.automigrate()
|
||||
Model.migrate(autogenerate=True)
|
||||
assert len(list(versions.glob("*.py"))) == 6
|
||||
|
Loading…
Reference in New Issue
Block a user