improve typing for ObjectVars in ArrayVars

This commit is contained in:
Benedikt Bartscher 2025-02-01 14:54:39 +01:00
parent 28b067c174
commit 3b19e0f356
No known key found for this signature in database
2 changed files with 38 additions and 4 deletions

View File

@ -21,9 +21,11 @@ from typing import (
overload, overload,
) )
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import TypeVar from typing_extensions import TypeVar
from reflex import constants from reflex import constants
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_OPENING_TAG from reflex.constants.base import REFLEX_VAR_OPENING_TAG
from reflex.constants.colors import Color from reflex.constants.colors import Color
from reflex.utils.exceptions import VarTypeError from reflex.utils.exceptions import VarTypeError
@ -53,8 +55,11 @@ from .number import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE, ObjectVar
from .function import FunctionVar
from .object import ObjectVar from .object import ObjectVar
STRING_TYPE = TypeVar("STRING_TYPE", default=str) STRING_TYPE = TypeVar("STRING_TYPE", default=str)
@ -961,6 +966,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
i: int | NumberVar, i: int | NumberVar,
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ... ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
i: int | NumberVar,
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
i: int | NumberVar,
) -> ObjectVar[SQLA_TYPE]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
i: int | NumberVar,
) -> ObjectVar[DATACLASS_TYPE]: ...
@overload @overload
def __getitem__(self, i: int | NumberVar) -> Var: ... def __getitem__(self, i: int | NumberVar) -> Var: ...
@ -1648,10 +1671,6 @@ def repeat_array_operation(
) )
if TYPE_CHECKING:
from .function import FunctionVar
@var_operation @var_operation
def map_array_operation( def map_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], array: ArrayVar[ARRAY_VAR_TYPE],

View File

@ -8,6 +8,7 @@ import reflex as rx
from reflex.utils.types import GenericType from reflex.utils.types import GenericType
from reflex.vars.base import Var from reflex.vars.base import Var
from reflex.vars.object import LiteralObjectVar, ObjectVar from reflex.vars.object import LiteralObjectVar, ObjectVar
from reflex.vars.sequence import ArrayVar
class Bare: class Bare:
@ -65,6 +66,8 @@ class ObjectState(rx.State):
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel()) sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
base_list: rx.Field[list[Base]] = rx.field([Base()])
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_var_create(type_: GenericType) -> None: def test_var_create(type_: GenericType) -> None:
@ -127,11 +130,23 @@ def test_typing() -> None:
# Base # Base
var = ObjectState.base var = ObjectState.base
_ = assert_type(var, ObjectVar[Base]) _ = assert_type(var, ObjectVar[Base])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
_ = assert_type(list_var_0, ObjectVar[Base])
# Sqla # Sqla
var = ObjectState.sqlamodel var = ObjectState.sqlamodel
_ = assert_type(var, ObjectVar[SqlaModel]) _ = assert_type(var, ObjectVar[SqlaModel])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
_ = assert_type(list_var_0, ObjectVar[Base])
# Dataclass # Dataclass
var = ObjectState.dataclass var = ObjectState.dataclass
_ = assert_type(var, ObjectVar[Dataclass]) _ = assert_type(var, ObjectVar[Dataclass])
list_var = ObjectState.base_list
_ = assert_type(list_var, ArrayVar[list[Base]])
list_var_0 = list_var[0]
_ = assert_type(list_var_0, ObjectVar[Base])