diff --git a/pynecone/model.py b/pynecone/model.py index a2e64396d..367ca74c9 100644 --- a/pynecone/model.py +++ b/pynecone/model.py @@ -1,5 +1,7 @@ """Database built into Pynecone.""" +from typing import Optional + import sqlmodel from pynecone.base import Base @@ -25,7 +27,19 @@ class Model(Base, sqlmodel.SQLModel): """Base class to define a table in the database.""" # The primary key for the table. - id: int = sqlmodel.Field(primary_key=True) + id: Optional[int] = sqlmodel.Field(primary_key=True) + + def __init_subclass__(cls): + """Drop the default primary key field if any primary key field is defined.""" + non_default_primary_key_fields = [ + field_name + for field_name, field in cls.__fields__.items() + if field_name != "id" and getattr(field.field_info, "primary_key", None) + ] + if non_default_primary_key_fields: + cls.__fields__.pop("id", None) + + super().__init_subclass__() def dict(self, **kwargs): """Convert the object to a dictionary. diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 000000000..eb7c44a13 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,51 @@ +import pytest +import sqlmodel + +from pynecone.model import Model + + +@pytest.fixture +def model_default_primary() -> Model: + """Returns a model object with no defined primary key. + + Returns: + Model: Model object. + """ + + class ChildModel(Model): + name: str + + return ChildModel(name="name") # type: ignore + + +@pytest.fixture +def model_custom_primary() -> Model: + """Returns a model object with a custom primary key. + + Returns: + Model: Model object. + """ + + class ChildModel(Model): + custom_id: int = sqlmodel.Field(default=None, primary_key=True) + name: str + + return ChildModel(name="name") # type: ignore + + +def test_default_primary_key(model_default_primary): + """Test that if a primary key is not defined a default is added. + + Args: + model_default_primary: Fixture. + """ + assert "id" in model_default_primary.__class__.__fields__ + + +def test_custom_primary_key(model_custom_primary): + """Test that if a primary key is defined no default key is added. + + Args: + model_custom_primary: Fixture. + """ + assert "id" not in model_custom_primary.__class__.__fields__