improve typing for ObjectVars in ArrayVars
This commit is contained in:
parent
28b067c174
commit
3b19e0f356
@ -21,9 +21,11 @@ from typing import (
|
|||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
|
from reflex.base import Base
|
||||||
from reflex.constants.base import REFLEX_VAR_OPENING_TAG
|
from reflex.constants.base import REFLEX_VAR_OPENING_TAG
|
||||||
from reflex.constants.colors import Color
|
from reflex.constants.colors import Color
|
||||||
from reflex.utils.exceptions import VarTypeError
|
from reflex.utils.exceptions import VarTypeError
|
||||||
@ -53,8 +55,11 @@ from .number import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE, ObjectVar
|
||||||
|
from .function import FunctionVar
|
||||||
from .object import ObjectVar
|
from .object import ObjectVar
|
||||||
|
|
||||||
|
|
||||||
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
|
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
|
||||||
|
|
||||||
|
|
||||||
@ -961,6 +966,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
|
|||||||
i: int | NumberVar,
|
i: int | NumberVar,
|
||||||
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
|
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(
|
||||||
|
self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
|
||||||
|
i: int | NumberVar,
|
||||||
|
) -> ObjectVar[BASE_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(
|
||||||
|
self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
|
||||||
|
i: int | NumberVar,
|
||||||
|
) -> ObjectVar[SQLA_TYPE]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(
|
||||||
|
self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
|
||||||
|
i: int | NumberVar,
|
||||||
|
) -> ObjectVar[DATACLASS_TYPE]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, i: int | NumberVar) -> Var: ...
|
def __getitem__(self, i: int | NumberVar) -> Var: ...
|
||||||
|
|
||||||
@ -1648,10 +1671,6 @@ def repeat_array_operation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .function import FunctionVar
|
|
||||||
|
|
||||||
|
|
||||||
@var_operation
|
@var_operation
|
||||||
def map_array_operation(
|
def map_array_operation(
|
||||||
array: ArrayVar[ARRAY_VAR_TYPE],
|
array: ArrayVar[ARRAY_VAR_TYPE],
|
||||||
|
@ -8,6 +8,7 @@ import reflex as rx
|
|||||||
from reflex.utils.types import GenericType
|
from reflex.utils.types import GenericType
|
||||||
from reflex.vars.base import Var
|
from reflex.vars.base import Var
|
||||||
from reflex.vars.object import LiteralObjectVar, ObjectVar
|
from reflex.vars.object import LiteralObjectVar, ObjectVar
|
||||||
|
from reflex.vars.sequence import ArrayVar
|
||||||
|
|
||||||
|
|
||||||
class Bare:
|
class Bare:
|
||||||
@ -65,6 +66,8 @@ class ObjectState(rx.State):
|
|||||||
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
|
sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
|
||||||
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
|
dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
|
||||||
|
|
||||||
|
base_list: rx.Field[list[Base]] = rx.field([Base()])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
|
||||||
def test_var_create(type_: GenericType) -> None:
|
def test_var_create(type_: GenericType) -> None:
|
||||||
@ -127,11 +130,23 @@ def test_typing() -> None:
|
|||||||
# Base
|
# Base
|
||||||
var = ObjectState.base
|
var = ObjectState.base
|
||||||
_ = assert_type(var, ObjectVar[Base])
|
_ = assert_type(var, ObjectVar[Base])
|
||||||
|
list_var = ObjectState.base_list
|
||||||
|
_ = assert_type(list_var, ArrayVar[list[Base]])
|
||||||
|
list_var_0 = list_var[0]
|
||||||
|
_ = assert_type(list_var_0, ObjectVar[Base])
|
||||||
|
|
||||||
# Sqla
|
# Sqla
|
||||||
var = ObjectState.sqlamodel
|
var = ObjectState.sqlamodel
|
||||||
_ = assert_type(var, ObjectVar[SqlaModel])
|
_ = assert_type(var, ObjectVar[SqlaModel])
|
||||||
|
list_var = ObjectState.base_list
|
||||||
|
_ = assert_type(list_var, ArrayVar[list[Base]])
|
||||||
|
list_var_0 = list_var[0]
|
||||||
|
_ = assert_type(list_var_0, ObjectVar[Base])
|
||||||
|
|
||||||
# Dataclass
|
# Dataclass
|
||||||
var = ObjectState.dataclass
|
var = ObjectState.dataclass
|
||||||
_ = assert_type(var, ObjectVar[Dataclass])
|
_ = assert_type(var, ObjectVar[Dataclass])
|
||||||
|
list_var = ObjectState.base_list
|
||||||
|
_ = assert_type(list_var, ArrayVar[list[Base]])
|
||||||
|
list_var_0 = list_var[0]
|
||||||
|
_ = assert_type(list_var_0, ObjectVar[Base])
|
||||||
|
Loading…
Reference in New Issue
Block a user