improve rx.Field ObjectVar typing for sqlalchemy and dataclasses (#4728)
* improve rx.Field ObjectVar typing for sqlalchemy and dataclasses * enable parametrized objectvar tests for sqlamodel and dataclass * improve typing for ObjectVars in ArrayVars * ruffing * drop duplicate objectvar import * remove redundant overload * allow optional hints in rx.Field annotations to resolve to the correct var type
This commit is contained in:
parent
15da4e17bd
commit
2b7e4d6b4e
@ -40,6 +40,7 @@ from typing import (
|
|||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
|
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
@ -573,7 +574,7 @@ class Var(Generic[VAR_TYPE]):
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def create( # type: ignore[override]
|
def create( # pyright: ignore[reportOverlappingOverload]
|
||||||
cls,
|
cls,
|
||||||
value: bool,
|
value: bool,
|
||||||
_var_data: VarData | None = None,
|
_var_data: VarData | None = None,
|
||||||
@ -581,7 +582,7 @@ class Var(Generic[VAR_TYPE]):
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def create( # type: ignore[override]
|
def create(
|
||||||
cls,
|
cls,
|
||||||
value: int,
|
value: int,
|
||||||
_var_data: VarData | None = None,
|
_var_data: VarData | None = None,
|
||||||
@ -605,7 +606,7 @@ class Var(Generic[VAR_TYPE]):
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create( # pyright: ignore[reportOverlappingOverload]
|
||||||
cls,
|
cls,
|
||||||
value: None,
|
value: None,
|
||||||
_var_data: VarData | None = None,
|
_var_data: VarData | None = None,
|
||||||
@ -3182,10 +3183,16 @@ def dispatch(
|
|||||||
|
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
|
|
||||||
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
|
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
|
||||||
|
SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from _typeshed import DataclassInstance
|
||||||
|
|
||||||
|
DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
|
||||||
|
|
||||||
FIELD_TYPE = TypeVar("FIELD_TYPE")
|
FIELD_TYPE = TypeVar("FIELD_TYPE")
|
||||||
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
|
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
|
||||||
|
|
||||||
|
|
||||||
class Field(Generic[FIELD_TYPE]):
|
class Field(Generic[FIELD_TYPE]):
|
||||||
@ -3230,6 +3237,18 @@ class Field(Generic[FIELD_TYPE]):
|
|||||||
self: Field[BASE_TYPE], instance: None, owner: Any
|
self: Field[BASE_TYPE], instance: None, owner: Any
|
||||||
) -> ObjectVar[BASE_TYPE]: ...
|
) -> ObjectVar[BASE_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: Field[SQLA_TYPE], instance: None, owner: Any
|
||||||
|
) -> ObjectVar[SQLA_TYPE]: ...
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __get__(
|
||||||
|
self: Field[DATACLASS_TYPE], instance: None, owner: Any
|
||||||
|
) -> ObjectVar[DATACLASS_TYPE]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
|
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
|
||||||
|
|
||||||
|
@ -53,8 +53,11 @@ from .number import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE
|
||||||
|
from .function import FunctionVar
|
||||||
from .object import ObjectVar
|
from .object import ObjectVar
|
||||||
|
|
||||||
|
|
||||||
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
|
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
|
||||||
|
|
||||||
|
|
||||||
@ -961,6 +964,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
|
|||||||
i: int | NumberVar,
|
i: int | NumberVar,
|
||||||
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
|
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(
|
||||||
|
self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
|
||||||
|
i: int | NumberVar,
|
||||||
|
) -> ObjectVar[BASE_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(
|
||||||
|
self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
|
||||||
|
i: int | NumberVar,
|
||||||
|
) -> ObjectVar[SQLA_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(
|
||||||
|
self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
|
||||||
|
i: int | NumberVar,
|
||||||
|
) -> ObjectVar[DATACLASS_TYPE]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, i: int | NumberVar) -> Var: ...
|
def __getitem__(self, i: int | NumberVar) -> Var: ...
|
||||||
|
|
||||||
@ -1648,10 +1669,6 @@ def repeat_array_operation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .function import FunctionVar
|
|
||||||
|
|
||||||
|
|
||||||
@var_operation
|
@var_operation
|
||||||
def map_array_operation(
|
def map_array_operation(
|
||||||
array: ArrayVar[ARRAY_VAR_TYPE],
|
array: ArrayVar[ARRAY_VAR_TYPE],
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
|
import dataclasses
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
|
||||||
from typing_extensions import assert_type
|
from typing_extensions import assert_type
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
from reflex.utils.types import GenericType
|
from reflex.utils.types import GenericType
|
||||||
from reflex.vars.base import Var
|
from reflex.vars.base import Var
|
||||||
from reflex.vars.object import LiteralObjectVar, ObjectVar
|
from reflex.vars.object import LiteralObjectVar, ObjectVar
|
||||||
|
from reflex.vars.sequence import ArrayVar
|
||||||
|
|
||||||
|
|
||||||
class Bare:
|
class Bare:
|
||||||
@ -32,14 +36,44 @@ class Base(rx.Base):
|
|||||||
quantity: int = 0
|
quantity: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SqlaBase(DeclarativeBase, MappedAsDataclass):
|
||||||
|
"""Sqlalchemy declarative mapping base class."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqlaModel(SqlaBase):
|
||||||
|
"""A sqlalchemy model with a single attribute."""
|
||||||
|
|
||||||
|
__tablename__: str = "sqla_model"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
|
||||||
|
quantity: Mapped[int] = mapped_column(default=0)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Dataclass:
|
||||||
|
"""A dataclass with a single attribute."""
|
||||||
|
|
||||||
|
quantity: int = 0
|
||||||
|
|
||||||
|
|
||||||
class ObjectState(rx.State):
|
class ObjectState(rx.State):
|
||||||
"""A reflex state with bare and base objects."""
|
"""A reflex state with bare, base and sqlalchemy base vars."""
|
||||||
|
|
||||||
bare: rx.Field[Bare] = rx.field(Bare())
|
bare: rx.Field[Bare] = rx.field(Bare())
|
||||||
|
bare_optional: rx.Field[Bare | None] = rx.field(None)
|
||||||
base: rx.Field[Base] = rx.field(Base())
|
base: rx.Field[Base] = rx.field(Base())
|
||||||
|
base_optional: rx.Field[Base | None] = rx.field(None)
|
||||||
|
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
|
||||||
|
sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
|
||||||
|
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
|
||||||
|
dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
|
||||||
|
|
||||||
|
base_list: rx.Field[list[Base]] = rx.field([Base()])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
||||||
def test_var_create(type_: GenericType) -> None:
|
def test_var_create(type_: GenericType) -> None:
|
||||||
my_object = type_()
|
my_object = type_()
|
||||||
var = Var.create(my_object)
|
var = Var.create(my_object)
|
||||||
@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
|
|||||||
assert quantity._var_type is int
|
assert quantity._var_type is int
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
||||||
def test_literal_create(type_: GenericType) -> None:
|
def test_literal_create(type_: GenericType) -> None:
|
||||||
my_object = type_()
|
my_object = type_()
|
||||||
var = LiteralObjectVar.create(my_object)
|
var = LiteralObjectVar.create(my_object)
|
||||||
@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
|
|||||||
assert quantity._var_type is int
|
assert quantity._var_type is int
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
||||||
def test_guess(type_: GenericType) -> None:
|
def test_guess(type_: GenericType) -> None:
|
||||||
my_object = type_()
|
my_object = type_()
|
||||||
var = Var.create(my_object)
|
var = Var.create(my_object)
|
||||||
@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
|
|||||||
assert quantity._var_type is int
|
assert quantity._var_type is int
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
||||||
def test_state(type_: GenericType) -> None:
|
def test_state(type_: GenericType) -> None:
|
||||||
attr_name = type_.__name__.lower()
|
attr_name = type_.__name__.lower()
|
||||||
var = getattr(ObjectState, attr_name)
|
var = getattr(ObjectState, attr_name)
|
||||||
@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
|
|||||||
assert quantity._var_type is int
|
assert quantity._var_type is int
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare])
|
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
||||||
def test_state_to_operation(type_: GenericType) -> None:
|
def test_state_to_operation(type_: GenericType) -> None:
|
||||||
attr_name = type_.__name__.lower()
|
attr_name = type_.__name__.lower()
|
||||||
original_var = getattr(ObjectState, attr_name)
|
original_var = getattr(ObjectState, attr_name)
|
||||||
@ -100,3 +134,29 @@ def test_typing() -> None:
|
|||||||
# Base
|
# Base
|
||||||
var = ObjectState.base
|
var = ObjectState.base
|
||||||
_ = assert_type(var, ObjectVar[Base])
|
_ = assert_type(var, ObjectVar[Base])
|
||||||
|
optional_var = ObjectState.base_optional
|
||||||
|
_ = assert_type(optional_var, ObjectVar[Base | None])
|
||||||
|
list_var = ObjectState.base_list
|
||||||
|
_ = assert_type(list_var, ArrayVar[list[Base]])
|
||||||
|
list_var_0 = list_var[0]
|
||||||
|
_ = assert_type(list_var_0, ObjectVar[Base])
|
||||||
|
|
||||||
|
# Sqla
|
||||||
|
var = ObjectState.sqlamodel
|
||||||
|
_ = assert_type(var, ObjectVar[SqlaModel])
|
||||||
|
optional_var = ObjectState.sqlamodel_optional
|
||||||
|
_ = assert_type(optional_var, ObjectVar[SqlaModel | None])
|
||||||
|
list_var = ObjectState.base_list
|
||||||
|
_ = assert_type(list_var, ArrayVar[list[Base]])
|
||||||
|
list_var_0 = list_var[0]
|
||||||
|
_ = assert_type(list_var_0, ObjectVar[Base])
|
||||||
|
|
||||||
|
# Dataclass
|
||||||
|
var = ObjectState.dataclass
|
||||||
|
_ = assert_type(var, ObjectVar[Dataclass])
|
||||||
|
optional_var = ObjectState.dataclass_optional
|
||||||
|
_ = assert_type(optional_var, ObjectVar[Dataclass | None])
|
||||||
|
list_var = ObjectState.base_list
|
||||||
|
_ = assert_type(list_var, ArrayVar[list[Base]])
|
||||||
|
list_var_0 = list_var[0]
|
||||||
|
_ = assert_type(list_var_0, ObjectVar[Base])
|
||||||
|
Loading…
Reference in New Issue
Block a user