improve rx.Field ObjectVar typing for sqlalchemy and dataclasses
This commit is contained in:
parent
68547dce4c
commit
3c208da1ec
@ -40,6 +40,7 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
|
||||
|
||||
from reflex import constants
|
||||
@ -3183,6 +3184,12 @@ def dispatch(
|
||||
V = TypeVar("V")
|
||||
|
||||
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
|
||||
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
|
||||
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance)
|
||||
|
||||
FIELD_TYPE = TypeVar("FIELD_TYPE")
|
||||
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
|
||||
@ -3230,6 +3237,23 @@ class Field(Generic[FIELD_TYPE]):
|
||||
self: Field[BASE_TYPE], instance: None, owner: Any
|
||||
) -> ObjectVar[BASE_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: Field[BASE_TYPE], instance: None, owner: Any
|
||||
) -> ObjectVar[BASE_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: Field[SQLA_TYPE], instance: None, owner: Any
|
||||
) -> ObjectVar[SQLA_TYPE]: ...
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: Field[DATACLASS_TYPE], instance: None, owner: Any
|
||||
) -> ObjectVar[DATACLASS_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
from typing_extensions import assert_type
|
||||
|
||||
import reflex as rx
|
||||
@ -32,11 +35,32 @@ class Base(rx.Base):
|
||||
quantity: int = 0
|
||||
|
||||
|
||||
class SqlaBase(DeclarativeBase):
|
||||
"""Sqlalchemy declarative mapping base class."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SqlaModel(SqlaBase):
|
||||
"""A sqlalchemy model with a single attribute."""
|
||||
|
||||
quantity: Mapped[int] = mapped_column()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Dataclass:
|
||||
"""A dataclass with a single attribute."""
|
||||
|
||||
quantity: int = 0
|
||||
|
||||
|
||||
class ObjectState(rx.State):
|
||||
"""A reflex state with bare and base objects."""
|
||||
"""A reflex state with bare, base and sqlalchemy base vars."""
|
||||
|
||||
bare: rx.Field[Bare] = rx.field(Bare())
|
||||
base: rx.Field[Base] = rx.field(Base())
|
||||
sqla: rx.Field[SqlaModel] = rx.field(SqlaModel())
|
||||
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
||||
@ -100,3 +124,11 @@ def test_typing() -> None:
|
||||
# Base
|
||||
var = ObjectState.base
|
||||
_ = assert_type(var, ObjectVar[Base])
|
||||
|
||||
# Sqla
|
||||
var = ObjectState.sqla
|
||||
_ = assert_type(var, ObjectVar[SqlaModel])
|
||||
|
||||
# Dataclass
|
||||
var = ObjectState.dataclass
|
||||
_ = assert_type(var, ObjectVar[Dataclass])
|
||||
|
Loading…
Reference in New Issue
Block a user