diff --git a/reflex/model.py b/reflex/model.py index cd1b141f7..40dbc212d 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -5,7 +5,7 @@ from __future__ import annotations import os from collections import defaultdict from pathlib import Path -from typing import Any, Optional +from typing import Any, ClassVar, Optional, Type, Union import alembic.autogenerate import alembic.command @@ -51,6 +51,88 @@ def get_engine(url: str | None = None): return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args) +SQLModelOrSqlAlchemy = Union[ + Type[sqlmodel.SQLModel], Type[sqlalchemy.orm.DeclarativeBase] +] + + +class ModelRegistry: + """Registry for all models.""" + + models: ClassVar[set[SQLModelOrSqlAlchemy]] = set() + + # Cache the metadata to avoid re-creating it. + _metadata: ClassVar[sqlalchemy.MetaData | None] = None + + @classmethod + def register(cls, model: SQLModelOrSqlAlchemy): + """Register a model. Can be used directly or as a decorator. + + Args: + model: The model to register. + + Returns: + The model passed in as an argument (Allows decorator usage) + """ + cls.models.add(model) + return model + + @classmethod + def get_models(cls, include_empty: bool = False) -> set[SQLModelOrSqlAlchemy]: + """Get registered models. + + Args: + include_empty: If True, include models with empty metadata. + + Returns: + The registered models. + """ + if include_empty: + return cls.models + return { + model for model in cls.models if not cls._model_metadata_is_empty(model) + } + + @staticmethod + def _model_metadata_is_empty(model: SQLModelOrSqlAlchemy) -> bool: + """Check if the model metadata is empty. + + Args: + model: The model to check. + + Returns: + True if the model metadata is empty, False otherwise. + """ + return len(model.metadata.tables) == 0 + + @classmethod + def get_metadata(cls) -> sqlalchemy.MetaData: + """Get the database metadata. + + Returns: + The database metadata. + """ + if cls._metadata is not None: + return cls._metadata + + models = cls.get_models(include_empty=False) + + if len(models) == 1: + metadata = next(iter(models)).metadata + else: + # Merge the metadata from all the models. + # This allows mixing bare sqlalchemy models with sqlmodel models in one database. + metadata = sqlalchemy.MetaData() + for model in cls.get_models(): + for table in model.metadata.tables.values(): + table.to_metadata(metadata) + + # Cache the metadata + cls._metadata = metadata + + return metadata + + class Model(Base, sqlmodel.SQLModel): """Base class to define a table in the database.""" @@ -113,7 +195,7 @@ class Model(Base, sqlmodel.SQLModel): def create_all(): """Create all the tables.""" engine = get_engine() - sqlmodel.SQLModel.metadata.create_all(engine) + ModelRegistry.get_metadata().create_all(engine) @staticmethod def get_db_engine(): @@ -224,7 +306,7 @@ class Model(Base, sqlmodel.SQLModel): ) as env: env.configure( connection=connection, - target_metadata=sqlmodel.SQLModel.metadata, + target_metadata=ModelRegistry.get_metadata(), render_item=cls._alembic_render_item, process_revision_directives=writer, # type: ignore compare_type=False, @@ -300,7 +382,6 @@ class Model(Base, sqlmodel.SQLModel): return True @classmethod - @property def select(cls): """Select rows from the table. @@ -310,6 +391,9 @@ class Model(Base, sqlmodel.SQLModel): return sqlmodel.select(cls) +ModelRegistry.register(Model) + + def session(url: str | None = None) -> sqlmodel.Session: """Get a session to interact with the database.