reflex db migrate CLI and associated config (#1336)
This commit is contained in:
parent
391135e235
commit
4a661a5395
2
poetry.lock
generated
2
poetry.lock
generated
@ -2128,4 +2128,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.7"
|
python-versions = "^3.7"
|
||||||
content-hash = "f4586d218b5320f0b595db276e8426db6bfffb406e8108b8a1bd9e785b6407c4"
|
content-hash = "ac27016107e8a033aa39d9a712d3ef685132e22ede599a26214b17da6ff35829"
|
||||||
|
@ -45,6 +45,7 @@ websockets = "^10.4"
|
|||||||
starlette-admin = "^0.9.0"
|
starlette-admin = "^0.9.0"
|
||||||
python-dotenv = "^0.13.0"
|
python-dotenv = "^0.13.0"
|
||||||
importlib-metadata = {version = "^6.7.0", python = ">=3.7, <3.8"}
|
importlib-metadata = {version = "^6.7.0", python = ">=3.7, <3.8"}
|
||||||
|
alembic = "^1.11.1"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^7.1.2"
|
pytest = "^7.1.2"
|
||||||
@ -62,7 +63,6 @@ pandas = [
|
|||||||
]
|
]
|
||||||
asynctest = "^0.13.0"
|
asynctest = "^0.13.0"
|
||||||
pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
|
pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"}
|
||||||
alembic = "^1.11.1"
|
|
||||||
selenium = "^4.10.0"
|
selenium = "^4.10.0"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
|
@ -452,10 +452,6 @@ class App(Base):
|
|||||||
# Get the env mode.
|
# Get the env mode.
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
# Update models during hot reload.
|
|
||||||
if config.db_url is not None and not Model.automigrate():
|
|
||||||
Model.create_all()
|
|
||||||
|
|
||||||
# Empty the .web pages directory
|
# Empty the .web pages directory
|
||||||
compiler.purge_web_pages_dir()
|
compiler.purge_web_pages_dir()
|
||||||
|
|
||||||
|
@ -123,7 +123,10 @@ def compile_state(state: Type[State]) -> Dict:
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary of the compiled state.
|
A dictionary of the compiled state.
|
||||||
"""
|
"""
|
||||||
initial_state = state().dict()
|
try:
|
||||||
|
initial_state = state().dict()
|
||||||
|
except Exception:
|
||||||
|
initial_state = state().dict(include_computed=False)
|
||||||
initial_state.update(
|
initial_state.update(
|
||||||
{
|
{
|
||||||
"events": [{"name": get_hydrate_event(state)}],
|
"events": [{"name": get_hydrate_event(state)}],
|
||||||
|
100
reflex/model.py
100
reflex/model.py
@ -4,26 +4,20 @@ from collections import defaultdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import alembic.autogenerate
|
||||||
|
import alembic.command
|
||||||
|
import alembic.config
|
||||||
|
import alembic.operations.ops
|
||||||
|
import alembic.runtime.environment
|
||||||
|
import alembic.script
|
||||||
|
import alembic.util
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import sqlmodel
|
import sqlmodel
|
||||||
|
|
||||||
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.config import get_config
|
from reflex.config import get_config
|
||||||
|
from reflex.utils import console
|
||||||
from . import constants
|
|
||||||
|
|
||||||
try:
|
|
||||||
import alembic.autogenerate # pyright: ignore [reportMissingImports]
|
|
||||||
import alembic.command # pyright: ignore [reportMissingImports]
|
|
||||||
import alembic.operations.ops # pyright: ignore [reportMissingImports]
|
|
||||||
import alembic.runtime.environment # pyright: ignore [reportMissingImports]
|
|
||||||
import alembic.script # pyright: ignore [reportMissingImports]
|
|
||||||
import alembic.util # pyright: ignore [reportMissingImports]
|
|
||||||
from alembic.config import Config # pyright: ignore [reportMissingImports]
|
|
||||||
|
|
||||||
has_alembic = True
|
|
||||||
except ImportError:
|
|
||||||
has_alembic = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_engine(url: Optional[str] = None):
|
def get_engine(url: Optional[str] = None):
|
||||||
@ -42,6 +36,10 @@ def get_engine(url: Optional[str] = None):
|
|||||||
url = url or conf.db_url
|
url = url or conf.db_url
|
||||||
if url is None:
|
if url is None:
|
||||||
raise ValueError("No database url configured")
|
raise ValueError("No database url configured")
|
||||||
|
if not Path(constants.ALEMBIC_CONFIG).exists():
|
||||||
|
console.print(
|
||||||
|
"[red]Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||||
|
)
|
||||||
return sqlmodel.create_engine(
|
return sqlmodel.create_engine(
|
||||||
url,
|
url,
|
||||||
echo=False,
|
echo=False,
|
||||||
@ -100,7 +98,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple of (config, script_directory)
|
tuple of (config, script_directory)
|
||||||
"""
|
"""
|
||||||
config = Config(constants.ALEMBIC_CONFIG)
|
config = alembic.config.Config(constants.ALEMBIC_CONFIG)
|
||||||
return config, alembic.script.ScriptDirectory(
|
return config, alembic.script.ScriptDirectory(
|
||||||
config.get_main_option("script_location", default="version"),
|
config.get_main_option("script_location", default="version"),
|
||||||
)
|
)
|
||||||
@ -120,27 +118,45 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
|
See https://alembic.sqlalchemy.org/en/latest/api/runtime.html
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type_: one of "schema", "table", "column", "index",
|
type_: One of "schema", "table", "column", "index",
|
||||||
"unique_constraint", or "foreign_key_constraint"
|
"unique_constraint", or "foreign_key_constraint".
|
||||||
obj: the object being rendered
|
obj: The object being rendered.
|
||||||
autogen_context: shared AutogenContext passed to each render_item call
|
autogen_context: Shared AutogenContext passed to each render_item call.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
False - indicating that the default rendering should be used.
|
False - Indicating that the default rendering should be used.
|
||||||
"""
|
"""
|
||||||
autogen_context.imports.add("import sqlmodel")
|
autogen_context.imports.add("import sqlmodel")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool:
|
def alembic_init(cls):
|
||||||
|
"""Initialize alembic for the project."""
|
||||||
|
alembic.command.init(
|
||||||
|
config=alembic.config.Config(constants.ALEMBIC_CONFIG),
|
||||||
|
directory=str(Path(constants.ALEMBIC_CONFIG).parent / "alembic"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def alembic_autogenerate(
|
||||||
|
cls,
|
||||||
|
connection: sqlalchemy.engine.Connection,
|
||||||
|
message: Optional[str] = None,
|
||||||
|
write_migration_scripts: bool = True,
|
||||||
|
) -> bool:
|
||||||
"""Generate migration scripts for alembic-detectable changes.
|
"""Generate migration scripts for alembic-detectable changes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connection: sqlalchemy connection to use when detecting changes
|
connection: SQLAlchemy connection to use when detecting changes.
|
||||||
|
message: Human readable identifier describing the generated revision.
|
||||||
|
write_migration_scripts: If True, write autogenerated revisions to script directory.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True when changes have been detected.
|
True when changes have been detected.
|
||||||
"""
|
"""
|
||||||
|
if not Path(constants.ALEMBIC_CONFIG).exists():
|
||||||
|
return False
|
||||||
|
|
||||||
config, script_directory = cls._alembic_config()
|
config, script_directory = cls._alembic_config()
|
||||||
revision_context = alembic.autogenerate.api.RevisionContext(
|
revision_context = alembic.autogenerate.api.RevisionContext(
|
||||||
config=config,
|
config=config,
|
||||||
@ -149,6 +165,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
lambda: None,
|
lambda: None,
|
||||||
autogenerate=True,
|
autogenerate=True,
|
||||||
head="head",
|
head="head",
|
||||||
|
message=message,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
writer = alembic.autogenerate.rewriter.Rewriter()
|
writer = alembic.autogenerate.rewriter.Rewriter()
|
||||||
@ -156,7 +173,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
@writer.rewrites(alembic.operations.ops.AddColumnOp)
|
@writer.rewrites(alembic.operations.ops.AddColumnOp)
|
||||||
def render_add_column_with_server_default(context, revision, op):
|
def render_add_column_with_server_default(context, revision, op):
|
||||||
# Carry the sqlmodel default as server_default so that newly added
|
# Carry the sqlmodel default as server_default so that newly added
|
||||||
# columns get the desired default value in existing rows
|
# columns get the desired default value in existing rows.
|
||||||
if op.column.default is not None and op.column.server_default is None:
|
if op.column.default is not None and op.column.server_default is None:
|
||||||
op.column.server_default = sqlalchemy.DefaultClause(
|
op.column.server_default = sqlalchemy.DefaultClause(
|
||||||
sqlalchemy.sql.expression.literal(op.column.default.arg),
|
sqlalchemy.sql.expression.literal(op.column.default.arg),
|
||||||
@ -184,9 +201,9 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
|
upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops
|
||||||
if upgrade_ops is not None:
|
if upgrade_ops is not None:
|
||||||
changes_detected = bool(upgrade_ops.ops)
|
changes_detected = bool(upgrade_ops.ops)
|
||||||
if changes_detected:
|
if changes_detected and write_migration_scripts:
|
||||||
for _script in revision_context.generate_scripts():
|
# Must iterate the generator to actually write the scripts.
|
||||||
pass # must iterate to actually generate the scripts
|
_ = tuple(revision_context.generate_scripts())
|
||||||
return changes_detected
|
return changes_detected
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -198,15 +215,14 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
"""Apply alembic migrations up to the given revision.
|
"""Apply alembic migrations up to the given revision.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connection: sqlalchemy connection to use when performing upgrade
|
connection: SQLAlchemy connection to use when performing upgrade.
|
||||||
to_rev: revision to migrate towards
|
to_rev: Revision to migrate towards.
|
||||||
"""
|
"""
|
||||||
config, script_directory = cls._alembic_config()
|
config, script_directory = cls._alembic_config()
|
||||||
|
|
||||||
def run_upgrade(rev, context):
|
def run_upgrade(rev, context):
|
||||||
return script_directory._upgrade_revs(to_rev, rev)
|
return script_directory._upgrade_revs(to_rev, rev)
|
||||||
|
|
||||||
# apply updates to database
|
|
||||||
with alembic.runtime.environment.EnvironmentContext(
|
with alembic.runtime.environment.EnvironmentContext(
|
||||||
config=config,
|
config=config,
|
||||||
script=script_directory,
|
script=script_directory,
|
||||||
@ -216,28 +232,36 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
env.run_migrations()
|
env.run_migrations()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def automigrate(cls) -> Optional[bool]:
|
def migrate(cls, autogenerate: bool = False) -> Optional[bool]:
|
||||||
"""Generate and execute migrations for all sqlmodel Model classes.
|
"""Execute alembic migrations for all sqlmodel Model classes.
|
||||||
|
|
||||||
If alembic is not installed or has not been initialized for the project,
|
If alembic is not installed or has not been initialized for the project,
|
||||||
then no action is performed.
|
then no action is performed.
|
||||||
|
|
||||||
|
If there are no revisions currently tracked by alembic, then
|
||||||
|
an initial revision will be created based on sqlmodel metadata.
|
||||||
|
|
||||||
If models in the app have changed in incompatible ways that alembic
|
If models in the app have changed in incompatible ways that alembic
|
||||||
cannot automatically generate revisions for, the app may not be able to
|
cannot automatically generate revisions for, the app may not be able to
|
||||||
start up until migration scripts have been corrected by hand.
|
start up until migration scripts have been corrected by hand.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
autogenerate: If True, generate migration script and use it to upgrade schema
|
||||||
|
(otherwise, just bring the schema to current "head" revision).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True - indicating the process was successful
|
True - indicating the process was successful.
|
||||||
None - indicating the process was skipped
|
None - indicating the process was skipped.
|
||||||
"""
|
"""
|
||||||
if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists():
|
if not Path(constants.ALEMBIC_CONFIG).exists():
|
||||||
return
|
return
|
||||||
|
|
||||||
with cls.get_db_engine().connect() as connection:
|
with cls.get_db_engine().connect() as connection:
|
||||||
cls._alembic_upgrade(connection=connection)
|
cls._alembic_upgrade(connection=connection)
|
||||||
changes_detected = cls._alembic_autogenerate(connection=connection)
|
if autogenerate:
|
||||||
if changes_detected:
|
changes_detected = cls.alembic_autogenerate(connection=connection)
|
||||||
cls._alembic_upgrade(connection=connection)
|
if changes_detected:
|
||||||
|
cls._alembic_upgrade(connection=connection)
|
||||||
connection.commit()
|
connection.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||||||
import httpx
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants, model
|
||||||
from reflex.config import get_config
|
from reflex.config import get_config
|
||||||
from reflex.utils import build, console, exec, prerequisites, processes, telemetry
|
from reflex.utils import build, console, exec, prerequisites, processes, telemetry
|
||||||
|
|
||||||
@ -132,6 +132,9 @@ def run(
|
|||||||
# Check the admin dashboard settings.
|
# Check the admin dashboard settings.
|
||||||
prerequisites.check_admin_settings()
|
prerequisites.check_admin_settings()
|
||||||
|
|
||||||
|
# Warn if schema is not up to date.
|
||||||
|
prerequisites.check_schema_up_to_date()
|
||||||
|
|
||||||
# Get the frontend and backend commands, based on the environment.
|
# Get the frontend and backend commands, based on the environment.
|
||||||
setup_frontend = frontend_cmd = backend_cmd = None
|
setup_frontend = frontend_cmd = backend_cmd = None
|
||||||
if env == constants.Env.DEV:
|
if env == constants.Env.DEV:
|
||||||
@ -158,7 +161,6 @@ def run(
|
|||||||
target=frontend_cmd, args=(Path.cwd(), frontend_port, loglevel)
|
target=frontend_cmd, args=(Path.cwd(), frontend_port, loglevel)
|
||||||
).start()
|
).start()
|
||||||
if backend:
|
if backend:
|
||||||
build.setup_backend()
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=backend_cmd,
|
target=backend_cmd,
|
||||||
args=(app.__name__, backend_host, backend_port, loglevel),
|
args=(app.__name__, backend_host, backend_port, loglevel),
|
||||||
@ -258,6 +260,51 @@ def export(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
db_cli = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@db_cli.command(name="init")
|
||||||
|
def db_init():
|
||||||
|
"""Create database schema and migration configuration."""
|
||||||
|
if get_config().db_url is None:
|
||||||
|
console.print("[red]db_url is not configured, cannot initialize.")
|
||||||
|
if Path(constants.ALEMBIC_CONFIG).exists():
|
||||||
|
console.print(
|
||||||
|
"[red]Database is already initialized. Use "
|
||||||
|
"[bold]reflex db makemigrations[/bold] to create schema change "
|
||||||
|
"scripts and [bold]reflex db migrate[/bold] to apply migrations "
|
||||||
|
"to a new or existing database.",
|
||||||
|
)
|
||||||
|
prerequisites.get_app()
|
||||||
|
model.Model.alembic_init()
|
||||||
|
model.Model.migrate(autogenerate=True)
|
||||||
|
|
||||||
|
|
||||||
|
@db_cli.command()
|
||||||
|
def migrate():
|
||||||
|
"""Create or update database schema based on app models or existing migration scripts."""
|
||||||
|
prerequisites.get_app()
|
||||||
|
if not prerequisites.check_db_initialized():
|
||||||
|
return
|
||||||
|
model.Model.migrate()
|
||||||
|
prerequisites.check_schema_up_to_date()
|
||||||
|
|
||||||
|
|
||||||
|
@db_cli.command()
|
||||||
|
def makemigrations(
|
||||||
|
message: str = typer.Option(
|
||||||
|
None, help="Human readable identifier for the generated revision."
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Create autogenerated alembic migration scripts."""
|
||||||
|
prerequisites.get_app()
|
||||||
|
if not prerequisites.check_db_initialized():
|
||||||
|
return
|
||||||
|
with model.Model.get_db_engine().connect() as connection:
|
||||||
|
model.Model.alembic_autogenerate(connection=connection, message=message)
|
||||||
|
|
||||||
|
|
||||||
|
cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema")
|
||||||
main = cli
|
main = cli
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -12,7 +12,6 @@ from typing import Optional, Union
|
|||||||
from rich.progress import Progress
|
from rich.progress import Progress
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.config import get_config
|
|
||||||
from reflex.utils import path_ops, prerequisites
|
from reflex.utils import path_ops, prerequisites
|
||||||
from reflex.utils.processes import new_process
|
from reflex.utils.processes import new_process
|
||||||
|
|
||||||
@ -240,16 +239,3 @@ def setup_frontend_prod(
|
|||||||
"""
|
"""
|
||||||
setup_frontend(root, loglevel, disable_telemetry)
|
setup_frontend(root, loglevel, disable_telemetry)
|
||||||
export_app(loglevel=loglevel)
|
export_app(loglevel=loglevel)
|
||||||
|
|
||||||
|
|
||||||
def setup_backend():
|
|
||||||
"""Set up backend.
|
|
||||||
|
|
||||||
Specifically ensures backend database is updated when running --no-frontend.
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular imports.
|
|
||||||
from reflex.model import Model
|
|
||||||
|
|
||||||
config = get_config()
|
|
||||||
if config.db_url is not None:
|
|
||||||
Model.create_all()
|
|
||||||
|
@ -16,10 +16,11 @@ from types import ModuleType
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from alembic.util.exc import CommandError
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants, model
|
||||||
from reflex.config import get_config
|
from reflex.config import get_config
|
||||||
from reflex.utils import console, path_ops
|
from reflex.utils import console, path_ops
|
||||||
|
|
||||||
@ -370,6 +371,41 @@ def check_admin_settings():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_db_initialized() -> bool:
|
||||||
|
"""Check if the database migrations are initialized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if alembic is initialized (or if database is not used).
|
||||||
|
"""
|
||||||
|
if get_config().db_url is not None and not Path(constants.ALEMBIC_CONFIG).exists():
|
||||||
|
console.print(
|
||||||
|
"[red]Database is not initialized. Run [bold]reflex db init[/bold] first."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def check_schema_up_to_date():
|
||||||
|
"""Check if the sqlmodel metadata matches the current database schema."""
|
||||||
|
if get_config().db_url is None or not Path(constants.ALEMBIC_CONFIG).exists():
|
||||||
|
return
|
||||||
|
with model.Model.get_db_engine().connect() as connection:
|
||||||
|
try:
|
||||||
|
if model.Model.alembic_autogenerate(
|
||||||
|
connection=connection,
|
||||||
|
write_migration_scripts=False,
|
||||||
|
):
|
||||||
|
console.print(
|
||||||
|
"[red]Detected database schema changes. Run [bold]reflex db makemigrations[/bold] "
|
||||||
|
"to generate migration scripts.",
|
||||||
|
)
|
||||||
|
except CommandError as command_error:
|
||||||
|
if "Target database is not up to date." in str(command_error):
|
||||||
|
console.print(
|
||||||
|
f"[red]{command_error} Run [bold]reflex db migrate[/bold] to update database."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def migrate_to_reflex():
|
def migrate_to_reflex():
|
||||||
"""Migration from Pynecone to Reflex."""
|
"""Migration from Pynecone to Reflex."""
|
||||||
# Check if the old config file exists.
|
# Check if the old config file exists.
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -68,26 +66,28 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
tmp_working_dir: directory where database and migrations are stored
|
tmp_working_dir: directory where database and migrations are stored
|
||||||
monkeypatch: pytest fixture to overwrite attributes
|
monkeypatch: pytest fixture to overwrite attributes
|
||||||
"""
|
"""
|
||||||
subprocess.run(
|
|
||||||
[sys.executable, "-m", "alembic", "init", "alembic"],
|
|
||||||
cwd=tmp_working_dir,
|
|
||||||
)
|
|
||||||
alembic_ini = tmp_working_dir / "alembic.ini"
|
alembic_ini = tmp_working_dir / "alembic.ini"
|
||||||
versions = tmp_working_dir / "alembic" / "versions"
|
versions = tmp_working_dir / "alembic" / "versions"
|
||||||
assert alembic_ini.exists()
|
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
|
||||||
assert versions.exists()
|
|
||||||
|
|
||||||
config_mock = mock.Mock()
|
config_mock = mock.Mock()
|
||||||
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
|
config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db"
|
||||||
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
|
monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock))
|
||||||
monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini))
|
|
||||||
|
Model.alembic_init()
|
||||||
|
assert alembic_ini.exists()
|
||||||
|
assert versions.exists()
|
||||||
|
|
||||||
# initial table
|
# initial table
|
||||||
class AlembicThing(Model, table=True): # type: ignore
|
class AlembicThing(Model, table=True): # type: ignore
|
||||||
t1: str
|
t1: str
|
||||||
|
|
||||||
Model.automigrate()
|
with Model.get_db_engine().connect() as connection:
|
||||||
assert len(list(versions.glob("*.py"))) == 1
|
Model.alembic_autogenerate(connection=connection, message="Initial Revision")
|
||||||
|
Model.migrate()
|
||||||
|
version_scripts = list(versions.glob("*.py"))
|
||||||
|
assert len(version_scripts) == 1
|
||||||
|
assert version_scripts[0].name.endswith("initial_revision.py")
|
||||||
|
|
||||||
with reflex.model.session() as session:
|
with reflex.model.session() as session:
|
||||||
session.add(AlembicThing(id=None, t1="foo"))
|
session.add(AlembicThing(id=None, t1="foo"))
|
||||||
@ -100,7 +100,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
t1: str
|
t1: str
|
||||||
t2: str = "bar"
|
t2: str = "bar"
|
||||||
|
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 2
|
assert len(list(versions.glob("*.py"))) == 2
|
||||||
|
|
||||||
with reflex.model.session() as session:
|
with reflex.model.session() as session:
|
||||||
@ -114,7 +114,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
class AlembicThing(Model, table=True): # type: ignore
|
class AlembicThing(Model, table=True): # type: ignore
|
||||||
t2: str = "bar"
|
t2: str = "bar"
|
||||||
|
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 3
|
assert len(list(versions.glob("*.py"))) == 3
|
||||||
|
|
||||||
with reflex.model.session() as session:
|
with reflex.model.session() as session:
|
||||||
@ -127,7 +127,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
a: int = 42
|
a: int = 42
|
||||||
b: float = 4.2
|
b: float = 4.2
|
||||||
|
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 4
|
assert len(list(versions.glob("*.py"))) == 4
|
||||||
|
|
||||||
with reflex.model.session() as session:
|
with reflex.model.session() as session:
|
||||||
@ -139,7 +139,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
assert result[0].b == 4.2
|
assert result[0].b == 4.2
|
||||||
|
|
||||||
# No-op
|
# No-op
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 4
|
assert len(list(versions.glob("*.py"))) == 4
|
||||||
|
|
||||||
# drop table (AlembicSecond)
|
# drop table (AlembicSecond)
|
||||||
@ -148,7 +148,7 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
class AlembicThing(Model, table=True): # type: ignore
|
class AlembicThing(Model, table=True): # type: ignore
|
||||||
t2: str = "bar"
|
t2: str = "bar"
|
||||||
|
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 5
|
assert len(list(versions.glob("*.py"))) == 5
|
||||||
|
|
||||||
with reflex.model.session() as session:
|
with reflex.model.session() as session:
|
||||||
@ -166,12 +166,12 @@ def test_automigration(tmp_working_dir, monkeypatch):
|
|||||||
# changing column type not supported by default
|
# changing column type not supported by default
|
||||||
t2: int = 42
|
t2: int = 42
|
||||||
|
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 5
|
assert len(list(versions.glob("*.py"))) == 5
|
||||||
|
|
||||||
# clear all metadata to avoid influencing subsequent tests
|
# clear all metadata to avoid influencing subsequent tests
|
||||||
sqlmodel.SQLModel.metadata.clear()
|
sqlmodel.SQLModel.metadata.clear()
|
||||||
|
|
||||||
# drop remaining tables
|
# drop remaining tables
|
||||||
Model.automigrate()
|
Model.migrate(autogenerate=True)
|
||||||
assert len(list(versions.glob("*.py"))) == 6
|
assert len(list(versions.glob("*.py"))) == 6
|
||||||
|
Loading…
Reference in New Issue
Block a user