enable parametrized objectvar tests for sqlamodel and dataclass

This commit is contained in:
Benedikt Bartscher 2025-02-01 14:12:25 +01:00
parent 3c208da1ec
commit 28b067c174
No known key found for this signature in database

View File

@ -1,7 +1,7 @@
import dataclasses import dataclasses
import pytest import pytest
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
from typing_extensions import assert_type from typing_extensions import assert_type
import reflex as rx import reflex as rx
@ -35,7 +35,7 @@ class Base(rx.Base):
quantity: int = 0 quantity: int = 0
class SqlaBase(DeclarativeBase): class SqlaBase(DeclarativeBase, MappedAsDataclass):
"""Sqlalchemy declarative mapping base class.""" """Sqlalchemy declarative mapping base class."""
pass pass
@ -44,7 +44,10 @@ class SqlaBase(DeclarativeBase):
class SqlaModel(SqlaBase): class SqlaModel(SqlaBase):
"""A sqlalchemy model with a single attribute.""" """A sqlalchemy model with a single attribute."""
quantity: Mapped[int] = mapped_column() __tablename__: str = "sqla_model"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
quantity: Mapped[int] = mapped_column(default=0)
@dataclasses.dataclass @dataclasses.dataclass
@ -59,11 +62,11 @@ class ObjectState(rx.State):
bare: rx.Field[Bare] = rx.field(Bare()) bare: rx.Field[Bare] = rx.field(Bare())
base: rx.Field[Base] = rx.field(Base()) base: rx.Field[Base] = rx.field(Base())
sqla: rx.Field[SqlaModel] = rx.field(SqlaModel()) sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_var_create(type_: GenericType) -> None: def test_var_create(type_: GenericType) -> None:
my_object = type_() my_object = type_()
var = Var.create(my_object) var = Var.create(my_object)
@ -73,7 +76,7 @@ def test_var_create(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_literal_create(type_: GenericType) -> None: def test_literal_create(type_: GenericType) -> None:
my_object = type_() my_object = type_()
var = LiteralObjectVar.create(my_object) var = LiteralObjectVar.create(my_object)
@ -83,7 +86,7 @@ def test_literal_create(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_guess(type_: GenericType) -> None: def test_guess(type_: GenericType) -> None:
my_object = type_() my_object = type_()
var = Var.create(my_object) var = Var.create(my_object)
@ -94,7 +97,7 @@ def test_guess(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state(type_: GenericType) -> None: def test_state(type_: GenericType) -> None:
attr_name = type_.__name__.lower() attr_name = type_.__name__.lower()
var = getattr(ObjectState, attr_name) var = getattr(ObjectState, attr_name)
@ -104,7 +107,7 @@ def test_state(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state_to_operation(type_: GenericType) -> None: def test_state_to_operation(type_: GenericType) -> None:
attr_name = type_.__name__.lower() attr_name = type_.__name__.lower()
original_var = getattr(ObjectState, attr_name) original_var = getattr(ObjectState, attr_name)
@ -126,7 +129,7 @@ def test_typing() -> None:
_ = assert_type(var, ObjectVar[Base]) _ = assert_type(var, ObjectVar[Base])
# Sqla # Sqla
var = ObjectState.sqla var = ObjectState.sqlamodel
_ = assert_type(var, ObjectVar[SqlaModel]) _ = assert_type(var, ObjectVar[SqlaModel])
# Dataclass # Dataclass