Bare sqlalchemy metadata (#2355)

This commit is contained in:
benedikt-bartscher 2024-03-13 23:32:35 +01:00 committed by GitHub
parent 036afa951a
commit 5701a72c8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import os import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, ClassVar, Optional, Type, Union
import alembic.autogenerate import alembic.autogenerate
import alembic.command import alembic.command
@ -51,6 +51,88 @@ def get_engine(url: str | None = None):
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args) return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
SQLModelOrSqlAlchemy = Union[
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
]
class ModelRegistry:
"""Registry for all models."""
models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
# Cache the metadata to avoid re-creating it.
_metadata: ClassVar[sqlalchemy.MetaData | None] = None
@classmethod
def register(cls, model: SQLModelOrSqlAlchemy):
"""Register a model. Can be used directly or as a decorator.
Args:
model: The model to register.
Returns:
The model passed in as an argument (Allows decorator usage)
"""
cls.models.add(model)
return model
@classmethod
def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
"""Get registered models.
Args:
include_empty: If True, include models with empty metadata.
Returns:
The registered models.
"""
if include_empty:
return cls.models
return {
model for model in cls.models if not cls._model_metadata_is_empty(model)
}
@staticmethod
def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
"""Check if the model metadata is empty.
Args:
model: The model to check.
Returns:
True if the model metadata is empty, False otherwise.
"""
return len(model.metadata.tables) == 0
@classmethod
def get_metadata(cls) -> sqlalchemy.MetaData:
"""Get the database metadata.
Returns:
The database metadata.
"""
if cls._metadata is not None:
return cls._metadata
models = cls.get_models(include_empty=False)
if len(models) == 1:
metadata = next(iter(models)).metadata
else:
# Merge the metadata from all the models.
# This allows mixing bare sqlalchemy models with sqlmodel models in one database.
metadata = sqlalchemy.MetaData()
for model in cls.get_models():
for table in model.metadata.tables.values():
table.to_metadata(metadata)
# Cache the metadata
cls._metadata = metadata
return metadata
class Model(Base, sqlmodel.SQLModel): class Model(Base, sqlmodel.SQLModel):
"""Base class to define a table in the database.""" """Base class to define a table in the database."""
@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel):
def create_all(): def create_all():
"""Create all the tables.""" """Create all the tables."""
engine = get_engine() engine = get_engine()
sqlmodel.SQLModel.metadata.create_all(engine) ModelRegistry.get_metadata().create_all(engine)
@staticmethod @staticmethod
def get_db_engine(): def get_db_engine():
@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel):
) as env: ) as env:
env.configure( env.configure(
connection=connection, connection=connection,
target_metadata=sqlmodel.SQLModel.metadata, target_metadata=ModelRegistry.get_metadata(),
render_item=cls._alembic_render_item, render_item=cls._alembic_render_item,
process_revision_directives=writer, # type: ignore process_revision_directives=writer, # type: ignore
compare_type=False, compare_type=False,
@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel):
return True return True
@classmethod @classmethod
@property
def select(cls): def select(cls):
"""Select rows from the table. """Select rows from the table.
@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel):
return sqlmodel.select(cls) return sqlmodel.select(cls)
ModelRegistry.register(Model)
def session(url: str | None = None) -> sqlmodel.Session: def session(url: str | None = None) -> sqlmodel.Session:
"""Get a session to interact with the database. """Get a session to interact with the database.