diff --git a/poetry.lock b/poetry.lock index 1f91206a4..817e2f189 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2128,4 +2128,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "f4586d218b5320f0b595db276e8426db6bfffb406e8108b8a1bd9e785b6407c4" +content-hash = "ac27016107e8a033aa39d9a712d3ef685132e22ede599a26214b17da6ff35829" diff --git a/pyproject.toml b/pyproject.toml index 6a19bccde..f26a7348f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ websockets = "^10.4" starlette-admin = "^0.9.0" python-dotenv = "^0.13.0" importlib-metadata = {version = "^6.7.0", python = ">=3.7, <3.8"} +alembic = "^1.11.1" [tool.poetry.group.dev.dependencies] pytest = "^7.1.2" @@ -62,7 +63,6 @@ pandas = [ ] asynctest = "^0.13.0" pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"} -alembic = "^1.11.1" selenium = "^4.10.0" [tool.poetry.scripts] diff --git a/reflex/app.py b/reflex/app.py index b980d9c7b..490f868d8 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -452,10 +452,6 @@ class App(Base): # Get the env mode. config = get_config() - # Update models during hot reload. - if config.db_url is not None and not Model.automigrate(): - Model.create_all() - # Empty the .web pages directory compiler.purge_web_pages_dir() diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index c195f053d..63c75fb97 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -123,7 +123,10 @@ def compile_state(state: Type[State]) -> Dict: Returns: A dictionary of the compiled state. """ - initial_state = state().dict() + try: + initial_state = state().dict() + except Exception: + initial_state = state().dict(include_computed=False) initial_state.update( { "events": [{"name": get_hydrate_event(state)}], diff --git a/reflex/model.py b/reflex/model.py index 876eee0e5..2575be3c6 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -4,26 +4,20 @@ from collections import defaultdict from pathlib import Path from typing import Any, Optional +import alembic.autogenerate +import alembic.command +import alembic.config +import alembic.operations.ops +import alembic.runtime.environment +import alembic.script +import alembic.util import sqlalchemy import sqlmodel +from reflex import constants from reflex.base import Base from reflex.config import get_config - -from . import constants - -try: - import alembic.autogenerate # pyright: ignore [reportMissingImports] - import alembic.command # pyright: ignore [reportMissingImports] - import alembic.operations.ops # pyright: ignore [reportMissingImports] - import alembic.runtime.environment # pyright: ignore [reportMissingImports] - import alembic.script # pyright: ignore [reportMissingImports] - import alembic.util # pyright: ignore [reportMissingImports] - from alembic.config import Config # pyright: ignore [reportMissingImports] - - has_alembic = True -except ImportError: - has_alembic = False +from reflex.utils import console def get_engine(url: Optional[str] = None): @@ -42,6 +36,10 @@ def get_engine(url: Optional[str] = None): url = url or conf.db_url if url is None: raise ValueError("No database url configured") + if not Path(constants.ALEMBIC_CONFIG).exists(): + console.print( + "[red]Database is not initialized, run [bold]reflex db init[/bold] first." + ) return sqlmodel.create_engine( url, echo=False, @@ -100,7 +98,7 @@ class Model(Base, sqlmodel.SQLModel): Returns: tuple of (config, script_directory) """ - config = Config(constants.ALEMBIC_CONFIG) + config = alembic.config.Config(constants.ALEMBIC_CONFIG) return config, alembic.script.ScriptDirectory( config.get_main_option("script_location", default="version"), ) @@ -120,27 +118,45 @@ class Model(Base, sqlmodel.SQLModel): See https://alembic.sqlalchemy.org/en/latest/api/runtime.html Args: - type_: one of "schema", "table", "column", "index", - "unique_constraint", or "foreign_key_constraint" - obj: the object being rendered - autogen_context: shared AutogenContext passed to each render_item call + type_: One of "schema", "table", "column", "index", + "unique_constraint", or "foreign_key_constraint". + obj: The object being rendered. + autogen_context: Shared AutogenContext passed to each render_item call. Returns: - False - indicating that the default rendering should be used. + False - Indicating that the default rendering should be used. """ autogen_context.imports.add("import sqlmodel") return False @classmethod - def _alembic_autogenerate(cls, connection: sqlalchemy.engine.Connection) -> bool: + def alembic_init(cls): + """Initialize alembic for the project.""" + alembic.command.init( + config=alembic.config.Config(constants.ALEMBIC_CONFIG), + directory=str(Path(constants.ALEMBIC_CONFIG).parent / "alembic"), + ) + + @classmethod + def alembic_autogenerate( + cls, + connection: sqlalchemy.engine.Connection, + message: Optional[str] = None, + write_migration_scripts: bool = True, + ) -> bool: """Generate migration scripts for alembic-detectable changes. Args: - connection: sqlalchemy connection to use when detecting changes + connection: SQLAlchemy connection to use when detecting changes. + message: Human readable identifier describing the generated revision. + write_migration_scripts: If True, write autogenerated revisions to script directory. Returns: True when changes have been detected. """ + if not Path(constants.ALEMBIC_CONFIG).exists(): + return False + config, script_directory = cls._alembic_config() revision_context = alembic.autogenerate.api.RevisionContext( config=config, @@ -149,6 +165,7 @@ class Model(Base, sqlmodel.SQLModel): lambda: None, autogenerate=True, head="head", + message=message, ), ) writer = alembic.autogenerate.rewriter.Rewriter() @@ -156,7 +173,7 @@ class Model(Base, sqlmodel.SQLModel): @writer.rewrites(alembic.operations.ops.AddColumnOp) def render_add_column_with_server_default(context, revision, op): # Carry the sqlmodel default as server_default so that newly added - # columns get the desired default value in existing rows + # columns get the desired default value in existing rows. if op.column.default is not None and op.column.server_default is None: op.column.server_default = sqlalchemy.DefaultClause( sqlalchemy.sql.expression.literal(op.column.default.arg), @@ -184,9 +201,9 @@ class Model(Base, sqlmodel.SQLModel): upgrade_ops = revision_context.generated_revisions[-1].upgrade_ops if upgrade_ops is not None: changes_detected = bool(upgrade_ops.ops) - if changes_detected: - for _script in revision_context.generate_scripts(): - pass # must iterate to actually generate the scripts + if changes_detected and write_migration_scripts: + # Must iterate the generator to actually write the scripts. + _ = tuple(revision_context.generate_scripts()) return changes_detected @classmethod @@ -198,15 +215,14 @@ class Model(Base, sqlmodel.SQLModel): """Apply alembic migrations up to the given revision. Args: - connection: sqlalchemy connection to use when performing upgrade - to_rev: revision to migrate towards + connection: SQLAlchemy connection to use when performing upgrade. + to_rev: Revision to migrate towards. """ config, script_directory = cls._alembic_config() def run_upgrade(rev, context): return script_directory._upgrade_revs(to_rev, rev) - # apply updates to database with alembic.runtime.environment.EnvironmentContext( config=config, script=script_directory, @@ -216,28 +232,36 @@ class Model(Base, sqlmodel.SQLModel): env.run_migrations() @classmethod - def automigrate(cls) -> Optional[bool]: - """Generate and execute migrations for all sqlmodel Model classes. + def migrate(cls, autogenerate: bool = False) -> Optional[bool]: + """Execute alembic migrations for all sqlmodel Model classes. If alembic is not installed or has not been initialized for the project, then no action is performed. + If there are no revisions currently tracked by alembic, then + an initial revision will be created based on sqlmodel metadata. + If models in the app have changed in incompatible ways that alembic cannot automatically generate revisions for, the app may not be able to start up until migration scripts have been corrected by hand. + Args: + autogenerate: If True, generate migration script and use it to upgrade schema + (otherwise, just bring the schema to current "head" revision). + Returns: - True - indicating the process was successful - None - indicating the process was skipped + True - indicating the process was successful. + None - indicating the process was skipped. """ - if not has_alembic or not Path(constants.ALEMBIC_CONFIG).exists(): + if not Path(constants.ALEMBIC_CONFIG).exists(): return with cls.get_db_engine().connect() as connection: cls._alembic_upgrade(connection=connection) - changes_detected = cls._alembic_autogenerate(connection=connection) - if changes_detected: - cls._alembic_upgrade(connection=connection) + if autogenerate: + changes_detected = cls.alembic_autogenerate(connection=connection) + if changes_detected: + cls._alembic_upgrade(connection=connection) connection.commit() return True diff --git a/reflex/reflex.py b/reflex/reflex.py index c81fa12ab..3a9f90bf8 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -9,7 +9,7 @@ from pathlib import Path import httpx import typer -from reflex import constants +from reflex import constants, model from reflex.config import get_config from reflex.utils import build, console, exec, prerequisites, processes, telemetry @@ -132,6 +132,9 @@ def run( # Check the admin dashboard settings. prerequisites.check_admin_settings() + # Warn if schema is not up to date. + prerequisites.check_schema_up_to_date() + # Get the frontend and backend commands, based on the environment. setup_frontend = frontend_cmd = backend_cmd = None if env == constants.Env.DEV: @@ -158,7 +161,6 @@ def run( target=frontend_cmd, args=(Path.cwd(), frontend_port, loglevel) ).start() if backend: - build.setup_backend() threading.Thread( target=backend_cmd, args=(app.__name__, backend_host, backend_port, loglevel), @@ -258,6 +260,51 @@ def export( ) +db_cli = typer.Typer() + + +@db_cli.command(name="init") +def db_init(): + """Create database schema and migration configuration.""" + if get_config().db_url is None: + console.print("[red]db_url is not configured, cannot initialize.") + if Path(constants.ALEMBIC_CONFIG).exists(): + console.print( + "[red]Database is already initialized. Use " + "[bold]reflex db makemigrations[/bold] to create schema change " + "scripts and [bold]reflex db migrate[/bold] to apply migrations " + "to a new or existing database.", + ) + prerequisites.get_app() + model.Model.alembic_init() + model.Model.migrate(autogenerate=True) + + +@db_cli.command() +def migrate(): + """Create or update database schema based on app models or existing migration scripts.""" + prerequisites.get_app() + if not prerequisites.check_db_initialized(): + return + model.Model.migrate() + prerequisites.check_schema_up_to_date() + + +@db_cli.command() +def makemigrations( + message: str = typer.Option( + None, help="Human readable identifier for the generated revision." + ), +): + """Create autogenerated alembic migration scripts.""" + prerequisites.get_app() + if not prerequisites.check_db_initialized(): + return + with model.Model.get_db_engine().connect() as connection: + model.Model.alembic_autogenerate(connection=connection, message=message) + + +cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema") main = cli if __name__ == "__main__": diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 35947f450..a54ea63f7 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -12,7 +12,6 @@ from typing import Optional, Union from rich.progress import Progress from reflex import constants -from reflex.config import get_config from reflex.utils import path_ops, prerequisites from reflex.utils.processes import new_process @@ -240,16 +239,3 @@ def setup_frontend_prod( """ setup_frontend(root, loglevel, disable_telemetry) export_app(loglevel=loglevel) - - -def setup_backend(): - """Set up backend. - - Specifically ensures backend database is updated when running --no-frontend. - """ - # Import here to avoid circular imports. - from reflex.model import Model - - config = get_config() - if config.db_url is not None: - Model.create_all() diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 7d1919517..0e4516101 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -16,10 +16,11 @@ from types import ModuleType from typing import Optional import typer +from alembic.util.exc import CommandError from packaging import version from redis import Redis -from reflex import constants +from reflex import constants, model from reflex.config import get_config from reflex.utils import console, path_ops @@ -370,6 +371,41 @@ def check_admin_settings(): ) +def check_db_initialized() -> bool: + """Check if the database migrations are initialized. + + Returns: + True if alembic is initialized (or if database is not used). + """ + if get_config().db_url is not None and not Path(constants.ALEMBIC_CONFIG).exists(): + console.print( + "[red]Database is not initialized. Run [bold]reflex db init[/bold] first." + ) + return False + return True + + +def check_schema_up_to_date(): + """Check if the sqlmodel metadata matches the current database schema.""" + if get_config().db_url is None or not Path(constants.ALEMBIC_CONFIG).exists(): + return + with model.Model.get_db_engine().connect() as connection: + try: + if model.Model.alembic_autogenerate( + connection=connection, + write_migration_scripts=False, + ): + console.print( + "[red]Detected database schema changes. Run [bold]reflex db makemigrations[/bold] " + "to generate migration scripts.", + ) + except CommandError as command_error: + if "Target database is not up to date." in str(command_error): + console.print( + f"[red]{command_error} Run [bold]reflex db migrate[/bold] to update database." + ) + + def migrate_to_reflex(): """Migration from Pynecone to Reflex.""" # Check if the old config file exists. diff --git a/tests/test_model.py b/tests/test_model.py index 343b72b1b..0359d1917 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,3 @@ -import subprocess -import sys from unittest import mock import pytest @@ -68,26 +66,28 @@ def test_automigration(tmp_working_dir, monkeypatch): tmp_working_dir: directory where database and migrations are stored monkeypatch: pytest fixture to overwrite attributes """ - subprocess.run( - [sys.executable, "-m", "alembic", "init", "alembic"], - cwd=tmp_working_dir, - ) alembic_ini = tmp_working_dir / "alembic.ini" versions = tmp_working_dir / "alembic" / "versions" - assert alembic_ini.exists() - assert versions.exists() + monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini)) config_mock = mock.Mock() config_mock.db_url = f"sqlite:///{tmp_working_dir}/reflex.db" monkeypatch.setattr(reflex.model, "get_config", mock.Mock(return_value=config_mock)) - monkeypatch.setattr(reflex.constants, "ALEMBIC_CONFIG", str(alembic_ini)) + + Model.alembic_init() + assert alembic_ini.exists() + assert versions.exists() # initial table class AlembicThing(Model, table=True): # type: ignore t1: str - Model.automigrate() - assert len(list(versions.glob("*.py"))) == 1 + with Model.get_db_engine().connect() as connection: + Model.alembic_autogenerate(connection=connection, message="Initial Revision") + Model.migrate() + version_scripts = list(versions.glob("*.py")) + assert len(version_scripts) == 1 + assert version_scripts[0].name.endswith("initial_revision.py") with reflex.model.session() as session: session.add(AlembicThing(id=None, t1="foo")) @@ -100,7 +100,7 @@ def test_automigration(tmp_working_dir, monkeypatch): t1: str t2: str = "bar" - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 2 with reflex.model.session() as session: @@ -114,7 +114,7 @@ def test_automigration(tmp_working_dir, monkeypatch): class AlembicThing(Model, table=True): # type: ignore t2: str = "bar" - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 3 with reflex.model.session() as session: @@ -127,7 +127,7 @@ def test_automigration(tmp_working_dir, monkeypatch): a: int = 42 b: float = 4.2 - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 4 with reflex.model.session() as session: @@ -139,7 +139,7 @@ def test_automigration(tmp_working_dir, monkeypatch): assert result[0].b == 4.2 # No-op - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 4 # drop table (AlembicSecond) @@ -148,7 +148,7 @@ def test_automigration(tmp_working_dir, monkeypatch): class AlembicThing(Model, table=True): # type: ignore t2: str = "bar" - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 5 with reflex.model.session() as session: @@ -166,12 +166,12 @@ def test_automigration(tmp_working_dir, monkeypatch): # changing column type not supported by default t2: int = 42 - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 5 # clear all metadata to avoid influencing subsequent tests sqlmodel.SQLModel.metadata.clear() # drop remaining tables - Model.automigrate() + Model.migrate(autogenerate=True) assert len(list(versions.glob("*.py"))) == 6