From 8c04e2ae374b5dbf7718a5031cad768b9aa935fb Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 1 Feb 2025 21:17:12 +0100 Subject: [PATCH] allow optional hints in rx.Field annotations to resolve to the correct var type --- reflex/vars/base.py | 14 +++++++------- tests/units/vars/test_object.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index be66234cf..8609d46cc 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -574,7 +574,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( # type: ignore[override] + def create( # pyright: ignore[reportOverlappingOverload] cls, value: bool, _var_data: VarData | None = None, @@ -582,7 +582,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( # type: ignore[override] + def create( cls, value: int, _var_data: VarData | None = None, @@ -606,7 +606,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( + def create( # pyright: ignore[reportOverlappingOverload] cls, value: None, _var_data: VarData | None = None, @@ -3183,16 +3183,16 @@ def dispatch( V = TypeVar("V") -BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) -SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase) +BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None) +SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None) if TYPE_CHECKING: from _typeshed import DataclassInstance - DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance) + DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None) FIELD_TYPE = TypeVar("FIELD_TYPE") -MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) +MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None) class Field(Generic[FIELD_TYPE]): diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index 93b0288a3..90e34be96 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -62,9 +62,13 @@ class ObjectState(rx.State): """A reflex state with bare, base and sqlalchemy base vars.""" bare: rx.Field[Bare] = rx.field(Bare()) + bare_optional: rx.Field[Bare | None] = rx.field(None) base: rx.Field[Base] = rx.field(Base()) + base_optional: rx.Field[Base | None] = rx.field(None) sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel()) + sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None) dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) + dataclass_optional: rx.Field[Dataclass | None] = rx.field(None) base_list: rx.Field[list[Base]] = rx.field([Base()]) @@ -130,6 +134,8 @@ def test_typing() -> None: # Base var = ObjectState.base _ = assert_type(var, ObjectVar[Base]) + optional_var = ObjectState.base_optional + _ = assert_type(optional_var, ObjectVar[Base | None]) list_var = ObjectState.base_list _ = assert_type(list_var, ArrayVar[list[Base]]) list_var_0 = list_var[0] @@ -138,6 +144,8 @@ def test_typing() -> None: # Sqla var = ObjectState.sqlamodel _ = assert_type(var, ObjectVar[SqlaModel]) + optional_var = ObjectState.sqlamodel_optional + _ = assert_type(optional_var, ObjectVar[SqlaModel | None]) list_var = ObjectState.base_list _ = assert_type(list_var, ArrayVar[list[Base]]) list_var_0 = list_var[0] @@ -146,6 +154,8 @@ def test_typing() -> None: # Dataclass var = ObjectState.dataclass _ = assert_type(var, ObjectVar[Dataclass]) + optional_var = ObjectState.dataclass_optional + _ = assert_type(optional_var, ObjectVar[Dataclass | None]) list_var = ObjectState.base_list _ = assert_type(list_var, ArrayVar[list[Base]]) list_var_0 = list_var[0]