improve rx.Field ObjectVar typing for sqlalchemy and dataclasses
This commit is contained in:
parent
68547dce4c
commit
3c208da1ec
@ -40,6 +40,7 @@ from typing import (
|
|||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
|
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
@ -3183,6 +3184,12 @@ def dispatch(
|
|||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
|
|
||||||
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
|
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
|
||||||
|
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from _typeshed import DataclassInstance
|
||||||
|
|
||||||
|
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance)
|
||||||
|
|
||||||
FIELD_TYPE = TypeVar("FIELD_TYPE")
|
FIELD_TYPE = TypeVar("FIELD_TYPE")
|
||||||
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
|
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
|
||||||
@ -3230,6 +3237,23 @@ class Field(Generic[FIELD_TYPE]):
|
|||||||
self: Field[BASE_TYPE], instance: None, owner: Any
|
self: Field[BASE_TYPE], instance: None, owner: Any
|
||||||
) -> ObjectVar[BASE_TYPE]: ...
|
) -> ObjectVar[BASE_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: Field[BASE_TYPE], instance: None, owner: Any
|
||||||
|
) -> ObjectVar[BASE_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: Field[SQLA_TYPE], instance: None, owner: Any
|
||||||
|
) -> ObjectVar[SQLA_TYPE]: ...
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: Field[DATACLASS_TYPE], instance: None, owner: Any
|
||||||
|
) -> ObjectVar[DATACLASS_TYPE]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
|
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
|
||||||
|
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
|
import dataclasses
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
from typing_extensions import assert_type
|
from typing_extensions import assert_type
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
@ -32,11 +35,32 @@ class Base(rx.Base):
|
|||||||
quantity: int = 0
|
quantity: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SqlaBase(DeclarativeBase):
|
||||||
|
"""Sqlalchemy declarative mapping base class."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqlaModel(SqlaBase):
|
||||||
|
"""A sqlalchemy model with a single attribute."""
|
||||||
|
|
||||||
|
quantity: Mapped[int] = mapped_column()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Dataclass:
|
||||||
|
"""A dataclass with a single attribute."""
|
||||||
|
|
||||||
|
quantity: int = 0
|
||||||
|
|
||||||
|
|
||||||
class ObjectState(rx.State):
|
class ObjectState(rx.State):
|
||||||
"""A reflex state with bare and base objects."""
|
"""A reflex state with bare, base and sqlalchemy base vars."""
|
||||||
|
|
||||||
bare: rx.Field[Bare] = rx.field(Bare())
|
bare: rx.Field[Bare] = rx.field(Bare())
|
||||||
base: rx.Field[Base] = rx.field(Base())
|
base: rx.Field[Base] = rx.field(Base())
|
||||||
|
sqla: rx.Field[SqlaModel] = rx.field(SqlaModel())
|
||||||
|
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
@pytest.mark.parametrize("type_", [Base, Bare])
|
||||||
@ -100,3 +124,11 @@ def test_typing() -> None:
|
|||||||
# Base
|
# Base
|
||||||
var = ObjectState.base
|
var = ObjectState.base
|
||||||
_ = assert_type(var, ObjectVar[Base])
|
_ = assert_type(var, ObjectVar[Base])
|
||||||
|
|
||||||
|
# Sqla
|
||||||
|
var = ObjectState.sqla
|
||||||
|
_ = assert_type(var, ObjectVar[SqlaModel])
|
||||||
|
|
||||||
|
# Dataclass
|
||||||
|
var = ObjectState.dataclass
|
||||||
|
_ = assert_type(var, ObjectVar[Dataclass])
|
||||||
|
Loading…
Reference in New Issue
Block a user