Bare sqlalchemy metadata (#2355)
This commit is contained in:
parent
036afa951a
commit
5701a72c8f
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, ClassVar, Optional, Type, Union
|
||||||
|
|
||||||
import alembic.autogenerate
|
import alembic.autogenerate
|
||||||
import alembic.command
|
import alembic.command
|
||||||
@ -51,6 +51,88 @@ def get_engine(url: str | None = None):
|
|||||||
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
||||||
|
|
||||||
|
|
||||||
|
SQLModelOrSqlAlchemy = Union[
|
||||||
|
Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRegistry:
|
||||||
|
"""Registry for all models."""
|
||||||
|
|
||||||
|
models: ClassVar[set[SQLModelOrSqlAlchemy]] = set()
|
||||||
|
|
||||||
|
# Cache the metadata to avoid re-creating it.
|
||||||
|
_metadata: ClassVar[sqlalchemy.MetaData | None] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, model: SQLModelOrSqlAlchemy):
|
||||||
|
"""Register a model. Can be used directly or as a decorator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to register.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model passed in as an argument (Allows decorator usage)
|
||||||
|
"""
|
||||||
|
cls.models.add(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]:
|
||||||
|
"""Get registered models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_empty: If True, include models with empty metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The registered models.
|
||||||
|
"""
|
||||||
|
if include_empty:
|
||||||
|
return cls.models
|
||||||
|
return {
|
||||||
|
model for model in cls.models if not cls._model_metadata_is_empty(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool:
|
||||||
|
"""Check if the model metadata is empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model metadata is empty, False otherwise.
|
||||||
|
"""
|
||||||
|
return len(model.metadata.tables) == 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_metadata(cls) -> sqlalchemy.MetaData:
|
||||||
|
"""Get the database metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The database metadata.
|
||||||
|
"""
|
||||||
|
if cls._metadata is not None:
|
||||||
|
return cls._metadata
|
||||||
|
|
||||||
|
models = cls.get_models(include_empty=False)
|
||||||
|
|
||||||
|
if len(models) == 1:
|
||||||
|
metadata = next(iter(models)).metadata
|
||||||
|
else:
|
||||||
|
# Merge the metadata from all the models.
|
||||||
|
# This allows mixing bare sqlalchemy models with sqlmodel models in one database.
|
||||||
|
metadata = sqlalchemy.MetaData()
|
||||||
|
for model in cls.get_models():
|
||||||
|
for table in model.metadata.tables.values():
|
||||||
|
table.to_metadata(metadata)
|
||||||
|
|
||||||
|
# Cache the metadata
|
||||||
|
cls._metadata = metadata
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
class Model(Base, sqlmodel.SQLModel):
|
class Model(Base, sqlmodel.SQLModel):
|
||||||
"""Base class to define a table in the database."""
|
"""Base class to define a table in the database."""
|
||||||
|
|
||||||
@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
def create_all():
|
def create_all():
|
||||||
"""Create all the tables."""
|
"""Create all the tables."""
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
sqlmodel.SQLModel.metadata.create_all(engine)
|
ModelRegistry.get_metadata().create_all(engine)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_db_engine():
|
def get_db_engine():
|
||||||
@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
) as env:
|
) as env:
|
||||||
env.configure(
|
env.configure(
|
||||||
connection=connection,
|
connection=connection,
|
||||||
target_metadata=sqlmodel.SQLModel.metadata,
|
target_metadata=ModelRegistry.get_metadata(),
|
||||||
render_item=cls._alembic_render_item,
|
render_item=cls._alembic_render_item,
|
||||||
process_revision_directives=writer, # type: ignore
|
process_revision_directives=writer, # type: ignore
|
||||||
compare_type=False,
|
compare_type=False,
|
||||||
@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@property
|
|
||||||
def select(cls):
|
def select(cls):
|
||||||
"""Select rows from the table.
|
"""Select rows from the table.
|
||||||
|
|
||||||
@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
return sqlmodel.select(cls)
|
return sqlmodel.select(cls)
|
||||||
|
|
||||||
|
|
||||||
|
ModelRegistry.register(Model)
|
||||||
|
|
||||||
|
|
||||||
def session(url: str | None = None) -> sqlmodel.Session:
|
def session(url: str | None = None) -> sqlmodel.Session:
|
||||||
"""Get a session to interact with the database.
|
"""Get a session to interact with the database.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user