allow optional hints in rx.Field annotations to resolve to the correct var type

This commit is contained in:
Benedikt Bartscher 2025-02-01 21:17:12 +01:00
parent 7e6bc323c9
commit 8c04e2ae37
No known key found for this signature in database
2 changed files with 17 additions and 7 deletions

View File

@ -574,7 +574,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
def create( # type: ignore[override]
def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: bool,
_var_data: VarData | None = None,
@ -582,7 +582,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
def create( # type: ignore[override]
def create(
cls,
value: int,
_var_data: VarData | None = None,
@ -606,7 +606,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
def create(
def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: None,
_var_data: VarData | None = None,
@ -3183,16 +3183,16 @@ def dispatch(
V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase)
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance)
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
class Field(Generic[FIELD_TYPE]):

View File

@ -62,9 +62,13 @@ class ObjectState(rx.State):
"""A reflex state with bare, base and sqlalchemy base vars."""
bare: rx.Field[Bare] = rx.field(Bare())
bare_optional: rx.Field[Bare | None] = rx.field(None)
base: rx.Field[Base] = rx.field(Base())
base_optional: rx.Field[Base | None] = rx.field(None)
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
base_list: rx.Field[list[Base]] = rx.field([Base()])
@ -130,6 +134,8 @@ def test_typing() -> None:
# Base
var = ObjectState.base
_ = assert_type(var, ObjectVar[Base])
optional_var = ObjectState.base_optional
_ = assert_type(optional_var, ObjectVar[Base | None])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
@ -138,6 +144,8 @@ def test_typing() -> None:
# Sqla
var = ObjectState.sqlamodel
_ = assert_type(var, ObjectVar[SqlaModel])
optional_var = ObjectState.sqlamodel_optional
_ = assert_type(optional_var, ObjectVar[SqlaModel | None])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
@ -146,6 +154,8 @@ def test_typing() -> None:
# Dataclass
var = ObjectState.dataclass
_ = assert_type(var, ObjectVar[Dataclass])
optional_var = ObjectState.dataclass_optional
_ = assert_type(optional_var, ObjectVar[Dataclass | None])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]