Bare sqlalchemy metadata (#2355)
This commit is contained in:
parent
036afa951a
commit
5701a72c8f
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, ClassVar, Optional, Type, Union
|
||||
|
||||
import alembic.autogenerate
|
||||
import alembic.command
|
||||
@ -51,6 +51,88 @@ def get_engine(url: str | None = None):
|
||||
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
||||
|
||||
|
||||
SQLModelOrSqlAlchemy = Union[
|
||||
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
|
||||
]
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Registry for all models."""
|
||||
|
||||
models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
|
||||
|
||||
# Cache the metadata to avoid re-creating it.
|
||||
_metadata: ClassVar[sqlalchemy.MetaData | None] = None
|
||||
|
||||
@classmethod
|
||||
def register(cls, model: SQLModelOrSqlAlchemy):
|
||||
"""Register a model. Can be used directly or as a decorator.
|
||||
|
||||
Args:
|
||||
model: The model to register.
|
||||
|
||||
Returns:
|
||||
The model passed in as an argument (Allows decorator usage)
|
||||
"""
|
||||
cls.models.add(model)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
|
||||
"""Get registered models.
|
||||
|
||||
Args:
|
||||
include_empty: If True, include models with empty metadata.
|
||||
|
||||
Returns:
|
||||
The registered models.
|
||||
"""
|
||||
if include_empty:
|
||||
return cls.models
|
||||
return {
|
||||
model for model in cls.models if not cls._model_metadata_is_empty(model)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
|
||||
"""Check if the model metadata is empty.
|
||||
|
||||
Args:
|
||||
model: The model to check.
|
||||
|
||||
Returns:
|
||||
True if the model metadata is empty, False otherwise.
|
||||
"""
|
||||
return len(model.metadata.tables) == 0
|
||||
|
||||
@classmethod
|
||||
def get_metadata(cls) -> sqlalchemy.MetaData:
|
||||
"""Get the database metadata.
|
||||
|
||||
Returns:
|
||||
The database metadata.
|
||||
"""
|
||||
if cls._metadata is not None:
|
||||
return cls._metadata
|
||||
|
||||
models = cls.get_models(include_empty=False)
|
||||
|
||||
if len(models) == 1:
|
||||
metadata = next(iter(models)).metadata
|
||||
else:
|
||||
# Merge the metadata from all the models.
|
||||
# This allows mixing bare sqlalchemy models with sqlmodel models in one database.
|
||||
metadata = sqlalchemy.MetaData()
|
||||
for model in cls.get_models():
|
||||
for table in model.metadata.tables.values():
|
||||
table.to_metadata(metadata)
|
||||
|
||||
# Cache the metadata
|
||||
cls._metadata = metadata
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
class Model(Base, sqlmodel.SQLModel):
|
||||
"""Base class to define a table in the database."""
|
||||
|
||||
@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
def create_all():
|
||||
"""Create all the tables."""
|
||||
engine = get_engine()
|
||||
sqlmodel.SQLModel.metadata.create_all(engine)
|
||||
ModelRegistry.get_metadata().create_all(engine)
|
||||
|
||||
@staticmethod
|
||||
def get_db_engine():
|
||||
@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
) as env:
|
||||
env.configure(
|
||||
connection=connection,
|
||||
target_metadata=sqlmodel.SQLModel.metadata,
|
||||
target_metadata=ModelRegistry.get_metadata(),
|
||||
render_item=cls._alembic_render_item,
|
||||
process_revision_directives=writer, # type: ignore
|
||||
compare_type=False,
|
||||
@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def select(cls):
|
||||
"""Select rows from the table.
|
||||
|
||||
@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
return sqlmodel.select(cls)
|
||||
|
||||
|
||||
ModelRegistry.register(Model)
|
||||
|
||||
|
||||
def session(url: str | None = None) -> sqlmodel.Session:
|
||||
"""Get a session to interact with the database.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user