diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index dfd9a6af8..b92fde7b3 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -21,9 +21,11 @@ from typing import ( overload, ) +from sqlalchemy.orm import DeclarativeBase from typing_extensions import TypeVar from reflex import constants +from reflex.base import Base from reflex.constants.base import REFLEX_VAR_OPENING_TAG from reflex.constants.colors import Color from reflex.utils.exceptions import VarTypeError @@ -53,8 +55,11 @@ from .number import ( ) if TYPE_CHECKING: + from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE, ObjectVar + from .function import FunctionVar from .object import ObjectVar + STRING_TYPE = TypeVar("STRING_TYPE", default=str) @@ -961,6 +966,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): i: int | NumberVar, ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ... + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE], + i: int | NumberVar, + ) -> ObjectVar[BASE_TYPE]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE], + i: int | NumberVar, + ) -> ObjectVar[SQLA_TYPE]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE], + i: int | NumberVar, + ) -> ObjectVar[DATACLASS_TYPE]: ... + @overload def __getitem__(self, i: int | NumberVar) -> Var: ... @@ -1648,10 +1671,6 @@ def repeat_array_operation( ) -if TYPE_CHECKING: - from .function import FunctionVar - - @var_operation def map_array_operation( array: ArrayVar[ARRAY_VAR_TYPE], diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index df53fca5b..93b0288a3 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -8,6 +8,7 @@ import reflex as rx from reflex.utils.types import GenericType from reflex.vars.base import Var from reflex.vars.object import LiteralObjectVar, ObjectVar +from reflex.vars.sequence import ArrayVar class Bare: @@ -65,6 +66,8 @@ class ObjectState(rx.State): sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel()) dataclass: rx.Field[Dataclass] = rx.field(Dataclass()) + base_list: rx.Field[list[Base]] = rx.field([Base()]) + @pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass]) def test_var_create(type_: GenericType) -> None: @@ -127,11 +130,23 @@ def test_typing() -> None: # Base var = ObjectState.base _ = assert_type(var, ObjectVar[Base]) + list_var = ObjectState.base_list + _ = assert_type(list_var, ArrayVar[list[Base]]) + list_var_0 = list_var[0] + _ = assert_type(list_var_0, ObjectVar[Base]) # Sqla var = ObjectState.sqlamodel _ = assert_type(var, ObjectVar[SqlaModel]) + list_var = ObjectState.base_list + _ = assert_type(list_var, ArrayVar[list[Base]]) + list_var_0 = list_var[0] + _ = assert_type(list_var_0, ObjectVar[Base]) # Dataclass var = ObjectState.dataclass _ = assert_type(var, ObjectVar[Dataclass]) + list_var = ObjectState.base_list + _ = assert_type(list_var, ArrayVar[list[Base]]) + list_var_0 = list_var[0] + _ = assert_type(list_var_0, ObjectVar[Base])