improve rx.Field ObjectVar typing for sqlalchemy and dataclasses (#4728)

* improve rx.Field ObjectVar typing for sqlalchemy and dataclasses

* enable parametrized objectvar tests for sqlamodel and dataclass

* improve typing for ObjectVars in ArrayVars

* ruffing

* drop duplicate objectvar import

* remove redundant overload

* allow optional hints in rx.Field annotations to resolve to the correct var type
This commit is contained in:
benedikt-bartscher 2025-02-03 18:33:22 +01:00 committed by GitHub
parent 15da4e17bd
commit 2b7e4d6b4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 111 additions and 15 deletions

View File

@ -40,6 +40,7 @@ from typing import (
overload, overload,
) )
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
from reflex import constants from reflex import constants
@ -573,7 +574,7 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
@classmethod @classmethod
def create( # type: ignore[override] def create( # pyright: ignore[reportOverlappingOverload]
cls, cls,
value: bool, value: bool,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -581,7 +582,7 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
@classmethod @classmethod
def create( # type: ignore[override] def create(
cls, cls,
value: int, value: int,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -605,7 +606,7 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
@classmethod @classmethod
def create( def create( # pyright: ignore[reportOverlappingOverload]
cls, cls,
value: None, value: None,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -3182,10 +3183,16 @@ def dispatch(
V = TypeVar("V") V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
FIELD_TYPE = TypeVar("FIELD_TYPE") FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
class Field(Generic[FIELD_TYPE]): class Field(Generic[FIELD_TYPE]):
@ -3230,6 +3237,18 @@ class Field(Generic[FIELD_TYPE]):
self: Field[BASE_TYPE], instance: None, owner: Any self: Field[BASE_TYPE], instance: None, owner: Any
) -> ObjectVar[BASE_TYPE]: ... ) -> ObjectVar[BASE_TYPE]: ...
@overload
def __get__(
self: Field[SQLA_TYPE], instance: None, owner: Any
) -> ObjectVar[SQLA_TYPE]: ...
if TYPE_CHECKING:
@overload
def __get__(
self: Field[DATACLASS_TYPE], instance: None, owner: Any
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload @overload
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ... def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...

View File

@ -53,8 +53,11 @@ from .number import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE
from .function import FunctionVar
from .object import ObjectVar from .object import ObjectVar
STRING_TYPE = TypeVar("STRING_TYPE", default=str) STRING_TYPE = TypeVar("STRING_TYPE", default=str)
@ -961,6 +964,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
i: int | NumberVar, i: int | NumberVar,
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ... ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
i: int | NumberVar,
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
i: int | NumberVar,
) -> ObjectVar[SQLA_TYPE]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
i: int | NumberVar,
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload @overload
def __getitem__(self, i: int | NumberVar) -> Var: ... def __getitem__(self, i: int | NumberVar) -> Var: ...
@ -1648,10 +1669,6 @@ def repeat_array_operation(
) )
if TYPE_CHECKING:
from .function import FunctionVar
@var_operation @var_operation
def map_array_operation( def map_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], array: ArrayVar[ARRAY_VAR_TYPE],

View File

@ -1,10 +1,14 @@
import dataclasses
import pytest import pytest
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
from typing_extensions import assert_type from typing_extensions import assert_type
import reflex as rx import reflex as rx
from reflex.utils.types import GenericType from reflex.utils.types import GenericType
from reflex.vars.base import Var from reflex.vars.base import Var
from reflex.vars.object import LiteralObjectVar, ObjectVar from reflex.vars.object import LiteralObjectVar, ObjectVar
from reflex.vars.sequence import ArrayVar
class Bare: class Bare:
@ -32,14 +36,44 @@ class Base(rx.Base):
quantity: int = 0 quantity: int = 0
class SqlaBase(DeclarativeBase, MappedAsDataclass):
"""Sqlalchemy declarative mapping base class."""
pass
class SqlaModel(SqlaBase):
"""A sqlalchemy model with a single attribute."""
__tablename__: str = "sqla_model"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
quantity: Mapped[int] = mapped_column(default=0)
@dataclasses.dataclass
class Dataclass:
"""A dataclass with a single attribute."""
quantity: int = 0
class ObjectState(rx.State): class ObjectState(rx.State):
"""A reflex state with bare and base objects.""" """A reflex state with bare, base and sqlalchemy base vars."""
bare: rx.Field[Bare] = rx.field(Bare()) bare: rx.Field[Bare] = rx.field(Bare())
bare_optional: rx.Field[Bare | None] = rx.field(None)
base: rx.Field[Base] = rx.field(Base()) base: rx.Field[Base] = rx.field(Base())
base_optional: rx.Field[Base | None] = rx.field(None)
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
base_list: rx.Field[list[Base]] = rx.field([Base()])
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_var_create(type_: GenericType) -> None: def test_var_create(type_: GenericType) -> None:
my_object = type_() my_object = type_()
var = Var.create(my_object) var = Var.create(my_object)
@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_literal_create(type_: GenericType) -> None: def test_literal_create(type_: GenericType) -> None:
my_object = type_() my_object = type_()
var = LiteralObjectVar.create(my_object) var = LiteralObjectVar.create(my_object)
@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_guess(type_: GenericType) -> None: def test_guess(type_: GenericType) -> None:
my_object = type_() my_object = type_()
var = Var.create(my_object) var = Var.create(my_object)
@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state(type_: GenericType) -> None: def test_state(type_: GenericType) -> None:
attr_name = type_.__name__.lower() attr_name = type_.__name__.lower()
var = getattr(ObjectState, attr_name) var = getattr(ObjectState, attr_name)
@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
assert quantity._var_type is int assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state_to_operation(type_: GenericType) -> None: def test_state_to_operation(type_: GenericType) -> None:
attr_name = type_.__name__.lower() attr_name = type_.__name__.lower()
original_var = getattr(ObjectState, attr_name) original_var = getattr(ObjectState, attr_name)
@ -100,3 +134,29 @@ def test_typing() -> None:
# Base # Base
var = ObjectState.base var = ObjectState.base
_ = assert_type(var, ObjectVar[Base]) _ = assert_type(var, ObjectVar[Base])
optional_var = ObjectState.base_optional
_ = assert_type(optional_var, ObjectVar[Base | None])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
_ = assert_type(list_var_0, ObjectVar[Base])
# Sqla
var = ObjectState.sqlamodel
_ = assert_type(var, ObjectVar[SqlaModel])
optional_var = ObjectState.sqlamodel_optional
_ = assert_type(optional_var, ObjectVar[SqlaModel | None])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
_ = assert_type(list_var_0, ObjectVar[Base])
# Dataclass
var = ObjectState.dataclass
_ = assert_type(var, ObjectVar[Dataclass])
optional_var = ObjectState.dataclass_optional
_ = assert_type(optional_var, ObjectVar[Dataclass | None])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
_ = assert_type(list_var_0, ObjectVar[Base])