diff --git a/reflex/vars/base.py b/reflex/vars/base.py index ec65c3711..8fc2530a1 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -40,6 +40,7 @@ from typing import ( overload, ) +from sqlalchemy.orm import DeclarativeBase from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override from reflex import constants @@ -3183,6 +3184,12 @@ def dispatch( V = TypeVar("V") BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) +SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase) + +if TYPE_CHECKING: + from _typeshed import DataclassInstance + + DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance) FIELD_TYPE = TypeVar("FIELD_TYPE") MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) @@ -3230,6 +3237,23 @@ class Field(Generic[FIELD_TYPE]): self: Field[BASE_TYPE], instance: None, owner: Any ) -> ObjectVar[BASE_TYPE]: ... + @overload + def __get__( + self: Field[BASE_TYPE], instance: None, owner: Any + ) -> ObjectVar[BASE_TYPE]: ... + + @overload + def __get__( + self: Field[SQLA_TYPE], instance: None, owner: Any + ) -> ObjectVar[SQLA_TYPE]: ... + + if TYPE_CHECKING: + + @overload + def __get__( + self: Field[DATACLASS_TYPE], instance: None, owner: Any + ) -> ObjectVar[DATACLASS_TYPE]: ... + @overload def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ... diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index efcb21166..8e204c2da 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -1,4 +1,7 @@ +import dataclasses + import pytest +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from typing_extensions import assert_type import reflex as rx @@ -32,11 +35,32 @@ class Base(rx.Base): quantity: int = 0 +class SqlaBase(DeclarativeBase): + """Sqlalchemy declarative mapping base class.""" + + pass + + +class SqlaModel(SqlaBase): + """A sqlalchemy model with a single attribute.""" + + quantity: Mapped[int] = mapped_column() + + +@dataclasses.dataclass +class Dataclass: + """A dataclass with a single attribute.""" + + quantity: int = 0 + + class ObjectState(rx.State): - """A reflex state with bare and base objects.""" + """A reflex state with bare, base and sqlalchemy base vars.""" bare: rx.Field[Bare] = rx.field(Bare()) base: rx.Field[Base] = rx.field(Base()) + sqla: rx.Field[SqlaModel] = rx.field(SqlaModel()) + dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) @pytest.mark.parametrize("type_", [Base, Bare]) @@ -100,3 +124,11 @@ def test_typing() -> None: # Base var = ObjectState.base _ = assert_type(var, ObjectVar[Base]) + + # Sqla + var = ObjectState.sqla + _ = assert_type(var, ObjectVar[SqlaModel]) + + # Dataclass + var = ObjectState.dataclass + _ = assert_type(var, ObjectVar[Dataclass])