From 2b7e4d6b4e0ae41856080c2ff4db0c2f67f52476 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Mon, 3 Feb 2025 18:33:22 +0100 Subject: [PATCH] improve rx.Field ObjectVar typing for sqlalchemy and dataclasses (#4728) * improve rx.Field ObjectVar typing for sqlalchemy and dataclasses * enable parametrized objectvar tests for sqlamodel and dataclass * improve typing for ObjectVars in ArrayVars * ruffing * drop duplicate objectvar import * remove redundant overload * allow optional hints in rx.Field annotations to resolve to the correct var type --- reflex/vars/base.py | 29 ++++++++++--- reflex/vars/sequence.py | 25 ++++++++++-- tests/units/vars/test_object.py | 72 ++++++++++++++++++++++++++++++--- 3 files changed, 111 insertions(+), 15 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index ec65c3711..8609d46cc 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -40,6 +40,7 @@ from typing import ( overload, ) +from sqlalchemy.orm import DeclarativeBase from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override from reflex import constants @@ -573,7 +574,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( # type: ignore[override] + def create( # pyright: ignore[reportOverlappingOverload] cls, value: bool, _var_data: VarData | None = None, @@ -581,7 +582,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( # type: ignore[override] + def create( cls, value: int, _var_data: VarData | None = None, @@ -605,7 +606,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( + def create( # pyright: ignore[reportOverlappingOverload] cls, value: None, _var_data: VarData | None = None, @@ -3182,10 +3183,16 @@ def dispatch( V = TypeVar("V") -BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) +BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None) +SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None) + +if TYPE_CHECKING: + from _typeshed import DataclassInstance + + DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None) FIELD_TYPE = TypeVar("FIELD_TYPE") -MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) +MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None) class Field(Generic[FIELD_TYPE]): @@ -3230,6 +3237,18 @@ class Field(Generic[FIELD_TYPE]): self: Field[BASE_TYPE], instance: None, owner: Any ) -> ObjectVar[BASE_TYPE]: ... + @overload + def __get__( + self: Field[SQLA_TYPE], instance: None, owner: Any + ) -> ObjectVar[SQLA_TYPE]: ... + + if TYPE_CHECKING: + + @overload + def __get__( + self: Field[DATACLASS_TYPE], instance: None, owner: Any + ) -> ObjectVar[DATACLASS_TYPE]: ... + @overload def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ... diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index dfd9a6af8..fb797b4ec 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -53,8 +53,11 @@ from .number import ( ) if TYPE_CHECKING: + from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE + from .function import FunctionVar from .object import ObjectVar + STRING_TYPE = TypeVar("STRING_TYPE", default=str) @@ -961,6 +964,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): i: int | NumberVar, ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ... + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE], + i: int | NumberVar, + ) -> ObjectVar[BASE_TYPE]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE], + i: int | NumberVar, + ) -> ObjectVar[SQLA_TYPE]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE], + i: int | NumberVar, + ) -> ObjectVar[DATACLASS_TYPE]: ... + @overload def __getitem__(self, i: int | NumberVar) -> Var: ... @@ -1648,10 +1669,6 @@ def repeat_array_operation( ) -if TYPE_CHECKING: - from .function import FunctionVar - - @var_operation def map_array_operation( array: ArrayVar[ARRAY_VAR_TYPE], diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index efcb21166..90e34be96 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -1,10 +1,14 @@ +import dataclasses + import pytest +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column from typing_extensions import assert_type import reflex as rx from reflex.utils.types import GenericType from reflex.vars.base import Var from reflex.vars.object import LiteralObjectVar, ObjectVar +from reflex.vars.sequence import ArrayVar class Bare: @@ -32,14 +36,44 @@ class Base(rx.Base): quantity: int = 0 +class SqlaBase(DeclarativeBase, MappedAsDataclass): + """Sqlalchemy declarative mapping base class.""" + + pass + + +class SqlaModel(SqlaBase): + """A sqlalchemy model with a single attribute.""" + + __tablename__: str = "sqla_model" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False) + quantity: Mapped[int] = mapped_column(default=0) + + +@dataclasses.dataclass +class Dataclass: + """A dataclass with a single attribute.""" + + quantity: int = 0 + + class ObjectState(rx.State): - """A reflex state with bare and base objects.""" + """A reflex state with bare, base and sqlalchemy base vars.""" bare: rx.Field[Bare] = rx.field(Bare()) + bare_optional: rx.Field[Bare | None] = rx.field(None) base: rx.Field[Base] = rx.field(Base()) + base_optional: rx.Field[Base | None] = rx.field(None) + sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel()) + sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None) + dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) + dataclass_optional: rx.Field[Dataclass | None] = rx.field(None) + + base_list: rx.Field[list[Base]] = rx.field([Base()]) -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_var_create(type_: GenericType) -> None: my_object = type_() var = Var.create(my_object) @@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_literal_create(type_: GenericType) -> None: my_object = type_() var = LiteralObjectVar.create(my_object) @@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_guess(type_: GenericType) -> None: my_object = type_() var = Var.create(my_object) @@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_state(type_: GenericType) -> None: attr_name = type_.__name__.lower() var = getattr(ObjectState, attr_name) @@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None: assert quantity._var_type is int -@pytest.mark.parametrize("type_", [Base, Bare]) +@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_state_to_operation(type_: GenericType) -> None: attr_name = type_.__name__.lower() original_var = getattr(ObjectState, attr_name) @@ -100,3 +134,29 @@ def test_typing() -> None: # Base var = ObjectState.base _ = assert_type(var, ObjectVar[Base]) + optional_var = ObjectState.base_optional + _ = assert_type(optional_var, ObjectVar[Base | None]) + list_var = ObjectState.base_list + _ = assert_type(list_var, ArrayVar[list[Base]]) + list_var_0 = list_var[0] + _ = assert_type(list_var_0, ObjectVar[Base]) + + # Sqla + var = ObjectState.sqlamodel + _ = assert_type(var, ObjectVar[SqlaModel]) + optional_var = ObjectState.sqlamodel_optional + _ = assert_type(optional_var, ObjectVar[SqlaModel | None]) + list_var = ObjectState.base_list + _ = assert_type(list_var, ArrayVar[list[Base]]) + list_var_0 = list_var[0] + _ = assert_type(list_var_0, ObjectVar[Base]) + + # Dataclass + var = ObjectState.dataclass + _ = assert_type(var, ObjectVar[Dataclass]) + optional_var = ObjectState.dataclass_optional + _ = assert_type(optional_var, ObjectVar[Dataclass | None]) + list_var = ObjectState.base_list + _ = assert_type(list_var, ArrayVar[list[Base]]) + list_var_0 = list_var[0] + _ = assert_type(list_var_0, ObjectVar[Base])