improve rx.Field ObjectVar typing for sqlalchemy and dataclasses

This commit is contained in:
Benedikt Bartscher 2025-02-01 14:07:24 +01:00
parent 68547dce4c
commit 3c208da1ec
No known key found for this signature in database
2 changed files with 57 additions and 1 deletions

View File

@ -40,6 +40,7 @@ from typing import (
overload,
)
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
from reflex import constants
@ -3183,6 +3184,12 @@ def dispatch(
V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance)
FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
@ -3230,6 +3237,23 @@ class Field(Generic[FIELD_TYPE]):
self: Field[BASE_TYPE], instance: None, owner: Any
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __get__(
self: Field[BASE_TYPE], instance: None, owner: Any
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __get__(
self: Field[SQLA_TYPE], instance: None, owner: Any
) -> ObjectVar[SQLA_TYPE]: ...
if TYPE_CHECKING:
@overload
def __get__(
self: Field[DATACLASS_TYPE], instance: None, owner: Any
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...

View File

@ -1,4 +1,7 @@
import dataclasses
import pytest
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from typing_extensions import assert_type
import reflex as rx
@ -32,11 +35,32 @@ class Base(rx.Base):
quantity: int = 0
class SqlaBase(DeclarativeBase):
"""Sqlalchemy declarative mapping base class."""
pass
class SqlaModel(SqlaBase):
"""A sqlalchemy model with a single attribute."""
quantity: Mapped[int] = mapped_column()
@dataclasses.dataclass
class Dataclass:
"""A dataclass with a single attribute."""
quantity: int = 0
class ObjectState(rx.State):
"""A reflex state with bare and base objects."""
"""A reflex state with bare, base and sqlalchemy base vars."""
bare: rx.Field[Bare] = rx.field(Bare())
base: rx.Field[Base] = rx.field(Base())
sqla: rx.Field[SqlaModel] = rx.field(SqlaModel())
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
@pytest.mark.parametrize("type_", [Base, Bare])
@ -100,3 +124,11 @@ def test_typing() -> None:
# Base
var = ObjectState.base
_ = assert_type(var, ObjectVar[Base])
# Sqla
var = ObjectState.sqla
_ = assert_type(var, ObjectVar[SqlaModel])
# Dataclass
var = ObjectState.dataclass
_ = assert_type(var, ObjectVar[Dataclass])