diff --git a/reflex/model.py b/reflex/model.py index c40fdf3f5..3daa49ded 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -15,6 +15,7 @@ import alembic.util import sqlalchemy import sqlalchemy.exc import sqlalchemy.orm +from alembic.runtime.migration import MigrationContext from reflex.base import Base from reflex.config import environment, get_config @@ -304,7 +305,11 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue writer = alembic.autogenerate.rewriter.Rewriter() @writer.rewrites(alembic.operations.ops.AddColumnOp) - def render_add_column_with_server_default(context, revision, op): + def render_add_column_with_server_default( + context: MigrationContext, + revision: str | None, + op: Any, + ): # Carry the sqlmodel default as server_default so that newly added # columns get the desired default value in existing rows. if op.column.default is not None and op.column.server_default is None: @@ -313,7 +318,7 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue ) return op - def run_autogenerate(rev, context): + def run_autogenerate(rev: str, context: MigrationContext): revision_context.run_autogenerate(rev, context) return [] @@ -355,7 +360,7 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue """ config, script_directory = cls._alembic_config() - def run_upgrade(rev, context): + def run_upgrade(rev: str, context: MigrationContext): return script_directory._upgrade_revs(to_rev, rev) with alembic.runtime.environment.EnvironmentContext(