From 28b067c174496525f8b511276d65a538d493d19d Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 1 Feb 2025 14:12:25 +0100 Subject: [PATCH] enable parametrized objectvar tests for sqlamodel and dataclass --- tests/units/vars/test_object.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index 8e204c2da..df53fca5b 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -1,7 +1,7 @@ import dataclasses import pytest -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column from typing_extensions import assert_type import reflex as rx @@ -35,7 +35,7 @@ class Base(rx.Base): quantity: int = 0 -class SqlaBase(DeclarativeBase): +class SqlaBase(DeclarativeBase, MappedAsDataclass): """Sqlalchemy declarative mapping base class.""" pass @@ -44,7 +44,10 @@ class SqlaBase(DeclarativeBase): class SqlaModel(SqlaBase): """A sqlalchemy model with a single attribute.""" - quantity: Mapped[int] = mapped_column() + __tablename__: str = "sqla_model" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False) + quantity: Mapped[int] = mapped_column(default=0) @dataclasses.dataclass @@ -59,11 +62,11 @@ class ObjectState(rx.State): bare: rx.Field[Bare] = rx.field(Bare()) base: rx.Field[Base] = rx.field(Base()) - sqla: rx.Field[SqlaModel] = rx.field(SqlaModel()) + sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel()) dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_var_create(type_: GenericType) -> None: my_object = type_() var = Var.create(my_object) @@ -73,7 +76,7 @@ def test_var_create(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_literal_create(type_: GenericType) -> None: my_object = type_() var = LiteralObjectVar.create(my_object) @@ -83,7 +86,7 @@ def test_literal_create(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_guess(type_: GenericType) -> None: my_object = type_() var = Var.create(my_object) @@ -94,7 +97,7 @@ def test_guess(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_state(type_: GenericType) -> None: attr_name = type_.__name__.lower() var = getattr(ObjectState, attr_name) @@ -104,7 +107,7 @@ def test_state(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_state_to_operation(type_: GenericType) -> None: attr_name = type_.__name__.lower() original_var = getattr(ObjectState, attr_name) @@ -126,7 +129,7 @@ def test_typing() -> None: _ = assert_type(var, ObjectVar[Base]) # Sqla - var = ObjectState.sqla + var = ObjectState.sqlamodel _ = assert_type(var, ObjectVar[SqlaModel]) # Dataclass