Bare sqlalchemy metadata (#2355)

This commit is contained in:
benedikt-bartscher 2024-03-13 23:32:35 +01:00 committed by GitHub
parent 036afa951a
commit 5701a72c8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Optional
from typing import Any, ClassVar, Optional, Type, Union
import alembic.autogenerate
import alembic.command
@ -51,6 +51,88 @@ def get_engine(url: str | None = None):
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
SQLModelOrSqlAlchemy = Union[
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
]
class ModelRegistry:
"""Registry for all models."""
models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
# Cache the metadata to avoid re-creating it.
_metadata: ClassVar[sqlalchemy.MetaData | None] = None
@classmethod
def register(cls, model: SQLModelOrSqlAlchemy):
"""Register a model. Can be used directly or as a decorator.
Args:
model: The model to register.
Returns:
The model passed in as an argument (Allows decorator usage)
"""
cls.models.add(model)
return model
@classmethod
def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
"""Get registered models.
Args:
include_empty: If True, include models with empty metadata.
Returns:
The registered models.
"""
if include_empty:
return cls.models
return {
model for model in cls.models if not cls._model_metadata_is_empty(model)
}
@staticmethod
def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
"""Check if the model metadata is empty.
Args:
model: The model to check.
Returns:
True if the model metadata is empty, False otherwise.
"""
return len(model.metadata.tables) == 0
@classmethod
def get_metadata(cls) -> sqlalchemy.MetaData:
"""Get the database metadata.
Returns:
The database metadata.
"""
if cls._metadata is not None:
return cls._metadata
models = cls.get_models(include_empty=False)
if len(models) == 1:
metadata = next(iter(models)).metadata
else:
# Merge the metadata from all the models.
# This allows mixing bare sqlalchemy models with sqlmodel models in one database.
metadata = sqlalchemy.MetaData()
for model in cls.get_models():
for table in model.metadata.tables.values():
table.to_metadata(metadata)
# Cache the metadata
cls._metadata = metadata
return metadata
class Model(Base, sqlmodel.SQLModel):
"""Base class to define a table in the database."""
@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel):
def create_all():
"""Create all the tables."""
engine = get_engine()
sqlmodel.SQLModel.metadata.create_all(engine)
ModelRegistry.get_metadata().create_all(engine)
@staticmethod
def get_db_engine():
@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel):
) as env:
env.configure(
connection=connection,
target_metadata=sqlmodel.SQLModel.metadata,
target_metadata=ModelRegistry.get_metadata(),
render_item=cls._alembic_render_item,
process_revision_directives=writer, # type: ignore
compare_type=False,
@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel):
return True
@classmethod
@property
def select(cls):
"""Select rows from the table.
@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel):
return sqlmodel.select(cls)
ModelRegistry.register(Model)
def session(url: str | None = None) -> sqlmodel.Session:
"""Get a session to interact with the database.