add type hinting to existing types (#3729)
* add type hinting to existing types * dang it darglint * i cannot
This commit is contained in:
parent
129adc941a
commit
ad14f38329
@ -10,9 +10,15 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
@ -42,13 +48,15 @@ if TYPE_CHECKING:
|
||||
from .object import ObjectVar, ToObjectOperation
|
||||
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
|
||||
|
||||
VAR_TYPE = TypeVar("VAR_TYPE")
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class ImmutableVar(Var):
|
||||
class ImmutableVar(Var, Generic[VAR_TYPE]):
|
||||
"""Base class for immutable vars."""
|
||||
|
||||
# The name of the var.
|
||||
@ -405,6 +413,8 @@ class ImmutableVar(Var):
|
||||
return self.to(ArrayVar, var_type)
|
||||
if issubclass(fixed_type, str):
|
||||
return self.to(StringVar)
|
||||
if issubclass(fixed_type, Base):
|
||||
return self.to(ObjectVar, var_type)
|
||||
return self
|
||||
|
||||
|
||||
@ -531,3 +541,43 @@ def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def unionize(*args: Type) -> Type:
|
||||
"""Unionize the types.
|
||||
|
||||
Args:
|
||||
args: The types to unionize.
|
||||
|
||||
Returns:
|
||||
The unionized types.
|
||||
"""
|
||||
if not args:
|
||||
return Any
|
||||
first, *rest = args
|
||||
if not rest:
|
||||
return first
|
||||
return Union[first, unionize(*rest)]
|
||||
|
||||
|
||||
def figure_out_type(value: Any) -> Type:
|
||||
"""Figure out the type of the value.
|
||||
|
||||
Args:
|
||||
value: The value to figure out the type of.
|
||||
|
||||
Returns:
|
||||
The type of the value.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return List[unionize(*(figure_out_type(v) for v in value))]
|
||||
if isinstance(value, set):
|
||||
return Set[unionize(*(figure_out_type(v) for v in value))]
|
||||
if isinstance(value, tuple):
|
||||
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
|
||||
if isinstance(value, dict):
|
||||
return Dict[
|
||||
unionize(*(figure_out_type(k) for k in value)),
|
||||
unionize(*(figure_out_type(v) for v in value.values())),
|
||||
]
|
||||
return type(value)
|
||||
|
@ -11,7 +11,7 @@ from reflex.experimental.vars.base import ImmutableVar, LiteralVar
|
||||
from reflex.vars import ImmutableVarData, Var, VarData
|
||||
|
||||
|
||||
class FunctionVar(ImmutableVar):
|
||||
class FunctionVar(ImmutableVar[Callable]):
|
||||
"""Base class for immutable function vars."""
|
||||
|
||||
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
|
||||
|
@ -15,7 +15,7 @@ from reflex.experimental.vars.base import (
|
||||
from reflex.vars import ImmutableVarData, Var, VarData
|
||||
|
||||
|
||||
class NumberVar(ImmutableVar):
|
||||
class NumberVar(ImmutableVar[Union[int, float]]):
|
||||
"""Base class for immutable number vars."""
|
||||
|
||||
def __add__(self, other: number_types | boolean_types) -> NumberAddOperation:
|
||||
@ -693,7 +693,7 @@ class NumberTruncOperation(UnaryNumberOperation):
|
||||
return f"Math.trunc({str(value)})"
|
||||
|
||||
|
||||
class BooleanVar(ImmutableVar):
|
||||
class BooleanVar(ImmutableVar[bool]):
|
||||
"""Base class for immutable boolean vars."""
|
||||
|
||||
def __and__(self, other: bool) -> BooleanAndOperation:
|
||||
|
@ -6,23 +6,69 @@ import dataclasses
|
||||
import sys
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Tuple, Type, Union
|
||||
from inspect import isclass
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
NoReturn,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
overload,
|
||||
)
|
||||
|
||||
from reflex.experimental.vars.base import ImmutableVar, LiteralVar
|
||||
from reflex.experimental.vars.sequence import ArrayVar, unionize
|
||||
from typing_extensions import get_origin
|
||||
|
||||
from reflex.experimental.vars.base import (
|
||||
ImmutableVar,
|
||||
LiteralVar,
|
||||
figure_out_type,
|
||||
)
|
||||
from reflex.experimental.vars.number import NumberVar
|
||||
from reflex.experimental.vars.sequence import ArrayVar, StringVar
|
||||
from reflex.utils.exceptions import VarAttributeError
|
||||
from reflex.utils.types import GenericType, get_attribute_access_type
|
||||
from reflex.vars import ImmutableVarData, Var, VarData
|
||||
|
||||
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
|
||||
|
||||
class ObjectVar(ImmutableVar):
|
||||
KEY_TYPE = TypeVar("KEY_TYPE")
|
||||
VALUE_TYPE = TypeVar("VALUE_TYPE")
|
||||
|
||||
ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
|
||||
|
||||
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
|
||||
|
||||
|
||||
class ObjectVar(ImmutableVar[OBJECT_TYPE]):
|
||||
"""Base class for immutable object vars."""
|
||||
|
||||
@overload
|
||||
def _key_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> KEY_TYPE: ...
|
||||
|
||||
@overload
|
||||
def _key_type(self) -> Type: ...
|
||||
|
||||
def _key_type(self) -> Type:
|
||||
"""Get the type of the keys of the object.
|
||||
|
||||
Returns:
|
||||
The type of the keys of the object.
|
||||
"""
|
||||
return ImmutableVar
|
||||
fixed_type = (
|
||||
self._var_type if isclass(self._var_type) else get_origin(self._var_type)
|
||||
)
|
||||
args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
|
||||
return args[0] if args else Any
|
||||
|
||||
@overload
|
||||
def _value_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> VALUE_TYPE: ...
|
||||
|
||||
@overload
|
||||
def _value_type(self) -> Type: ...
|
||||
|
||||
def _value_type(self) -> Type:
|
||||
"""Get the type of the values of the object.
|
||||
@ -30,9 +76,21 @@ class ObjectVar(ImmutableVar):
|
||||
Returns:
|
||||
The type of the values of the object.
|
||||
"""
|
||||
return ImmutableVar
|
||||
fixed_type = (
|
||||
self._var_type if isclass(self._var_type) else get_origin(self._var_type)
|
||||
)
|
||||
args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
|
||||
return args[1] if args else Any
|
||||
|
||||
def keys(self) -> ObjectKeysOperation:
|
||||
@overload
|
||||
def keys(
|
||||
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
|
||||
) -> ArrayVar[List[KEY_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def keys(self) -> ArrayVar: ...
|
||||
|
||||
def keys(self) -> ArrayVar:
|
||||
"""Get the keys of the object.
|
||||
|
||||
Returns:
|
||||
@ -40,7 +98,15 @@ class ObjectVar(ImmutableVar):
|
||||
"""
|
||||
return ObjectKeysOperation(self)
|
||||
|
||||
def values(self) -> ObjectValuesOperation:
|
||||
@overload
|
||||
def values(
|
||||
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
|
||||
) -> ArrayVar[List[VALUE_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def values(self) -> ArrayVar: ...
|
||||
|
||||
def values(self) -> ArrayVar:
|
||||
"""Get the values of the object.
|
||||
|
||||
Returns:
|
||||
@ -48,7 +114,15 @@ class ObjectVar(ImmutableVar):
|
||||
"""
|
||||
return ObjectValuesOperation(self)
|
||||
|
||||
def entries(self) -> ObjectEntriesOperation:
|
||||
@overload
|
||||
def entries(
|
||||
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
|
||||
) -> ArrayVar[List[Tuple[KEY_TYPE, VALUE_TYPE]]]: ...
|
||||
|
||||
@overload
|
||||
def entries(self) -> ArrayVar: ...
|
||||
|
||||
def entries(self) -> ArrayVar:
|
||||
"""Get the entries of the object.
|
||||
|
||||
Returns:
|
||||
@ -67,6 +141,53 @@ class ObjectVar(ImmutableVar):
|
||||
"""
|
||||
return ObjectMergeOperation(self, other)
|
||||
|
||||
# NoReturn is used here to catch when key value is Any
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
|
||||
key: Var | Any,
|
||||
) -> ImmutableVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: (
|
||||
ObjectVar[Dict[KEY_TYPE, int]]
|
||||
| ObjectVar[Dict[KEY_TYPE, float]]
|
||||
| ObjectVar[Dict[KEY_TYPE, int | float]]
|
||||
),
|
||||
key: Var | Any,
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, str]],
|
||||
key: Var | Any,
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
|
||||
key: Var | Any,
|
||||
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
|
||||
|
||||
def __getitem__(self, key: Var | Any) -> ImmutableVar:
|
||||
"""Get an item from the object.
|
||||
|
||||
@ -78,16 +199,78 @@ class ObjectVar(ImmutableVar):
|
||||
"""
|
||||
return ObjectItemOperation(self, key).guess_type()
|
||||
|
||||
def __getattr__(self, name) -> ObjectItemOperation:
|
||||
# NoReturn is used here to catch when key value is Any
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
|
||||
name: str,
|
||||
) -> ImmutableVar: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: (
|
||||
ObjectVar[Dict[KEY_TYPE, int]]
|
||||
| ObjectVar[Dict[KEY_TYPE, float]]
|
||||
| ObjectVar[Dict[KEY_TYPE, int | float]]
|
||||
),
|
||||
name: str,
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, str]],
|
||||
name: str,
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
|
||||
name: str,
|
||||
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
|
||||
name: str,
|
||||
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
|
||||
name: str,
|
||||
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
|
||||
name: str,
|
||||
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
|
||||
|
||||
def __getattr__(self, name) -> ImmutableVar:
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Raises:
|
||||
VarAttributeError: The State var has no such attribute or may have been annotated wrongly.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
return ObjectItemOperation(self, name)
|
||||
fixed_type = (
|
||||
self._var_type if isclass(self._var_type) else get_origin(self._var_type)
|
||||
)
|
||||
if not issubclass(fixed_type, dict):
|
||||
attribute_type = get_attribute_access_type(self._var_type, name)
|
||||
if attribute_type is None:
|
||||
raise VarAttributeError(
|
||||
f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated "
|
||||
f"wrongly."
|
||||
)
|
||||
return ObjectItemOperation(self, name, attribute_type).guess_type()
|
||||
else:
|
||||
return ObjectItemOperation(self, name).guess_type()
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
@ -95,7 +278,7 @@ class ObjectVar(ImmutableVar):
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class LiteralObjectVar(LiteralVar, ObjectVar):
|
||||
class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
|
||||
"""Base class for immutable literal object vars."""
|
||||
|
||||
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
|
||||
@ -103,9 +286,9 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_var_value: dict[Var | Any, Var | Any],
|
||||
_var_type: Type | None = None,
|
||||
self: LiteralObjectVar[OBJECT_TYPE],
|
||||
_var_value: OBJECT_TYPE,
|
||||
_var_type: Type[OBJECT_TYPE] | None = None,
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the object var.
|
||||
@ -117,14 +300,7 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
|
||||
"""
|
||||
super(LiteralObjectVar, self).__init__(
|
||||
_var_name="",
|
||||
_var_type=(
|
||||
Dict[
|
||||
unionize(*map(type, _var_value.keys())),
|
||||
unionize(*map(type, _var_value.values())),
|
||||
]
|
||||
if _var_type is None
|
||||
else _var_type
|
||||
),
|
||||
_var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(
|
||||
@ -489,6 +665,7 @@ class ObjectItemOperation(ImmutableVar):
|
||||
self,
|
||||
value: ObjectVar,
|
||||
key: Var | Any,
|
||||
_var_type: GenericType | None = None,
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the object item operation.
|
||||
@ -500,7 +677,7 @@ class ObjectItemOperation(ImmutableVar):
|
||||
"""
|
||||
super(ObjectItemOperation, self).__init__(
|
||||
_var_name="",
|
||||
_var_type=value._value_type(),
|
||||
_var_type=value._value_type() if _var_type is None else _var_type,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "value", value)
|
||||
|
@ -10,7 +10,18 @@ import re
|
||||
import sys
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from typing import Any, List, Set, Tuple, Type, Union, overload
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import get_origin
|
||||
|
||||
@ -19,6 +30,8 @@ from reflex.constants.base import REFLEX_VAR_OPENING_TAG
|
||||
from reflex.experimental.vars.base import (
|
||||
ImmutableVar,
|
||||
LiteralVar,
|
||||
figure_out_type,
|
||||
unionize,
|
||||
)
|
||||
from reflex.experimental.vars.number import (
|
||||
BooleanVar,
|
||||
@ -29,8 +42,11 @@ from reflex.experimental.vars.number import (
|
||||
from reflex.utils.types import GenericType
|
||||
from reflex.vars import ImmutableVarData, Var, VarData, _global_vars
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import ObjectVar
|
||||
|
||||
class StringVar(ImmutableVar):
|
||||
|
||||
class StringVar(ImmutableVar[str]):
|
||||
"""Base class for immutable string vars."""
|
||||
|
||||
def __add__(self, other: StringVar | str) -> ConcatVarOperation:
|
||||
@ -699,7 +715,17 @@ class ConcatVarOperation(StringVar):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayVar(ImmutableVar):
|
||||
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
|
||||
|
||||
OTHER_TUPLE = TypeVar("OTHER_TUPLE")
|
||||
|
||||
INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR")
|
||||
|
||||
KEY_TYPE = TypeVar("KEY_TYPE")
|
||||
VALUE_TYPE = TypeVar("VALUE_TYPE")
|
||||
|
||||
|
||||
class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]):
|
||||
"""Base class for immutable array vars."""
|
||||
|
||||
from reflex.experimental.vars.sequence import StringVar
|
||||
@ -717,7 +743,7 @@ class ArrayVar(ImmutableVar):
|
||||
|
||||
return ArrayJoinOperation(self, sep)
|
||||
|
||||
def reverse(self) -> ArrayReverseOperation:
|
||||
def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]:
|
||||
"""Reverse the array.
|
||||
|
||||
Returns:
|
||||
@ -726,14 +752,98 @@ class ArrayVar(ImmutableVar):
|
||||
return ArrayReverseOperation(self)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, i: slice) -> ArraySliceOperation: ...
|
||||
def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: (
|
||||
ArrayVar[Tuple[int, OTHER_TUPLE]]
|
||||
| ArrayVar[Tuple[float, OTHER_TUPLE]]
|
||||
| ArrayVar[Tuple[int | float, OTHER_TUPLE]]
|
||||
),
|
||||
i: Literal[0, -2],
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: (
|
||||
ArrayVar[Tuple[OTHER_TUPLE, int]]
|
||||
| ArrayVar[Tuple[OTHER_TUPLE, float]]
|
||||
| ArrayVar[Tuple[OTHER_TUPLE, int | float]]
|
||||
),
|
||||
i: Literal[1, -1],
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
|
||||
) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
|
||||
) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: (
|
||||
ARRAY_VAR_OF_LIST_ELEMENT[int]
|
||||
| ARRAY_VAR_OF_LIST_ELEMENT[float]
|
||||
| ARRAY_VAR_OF_LIST_ELEMENT[int | float]
|
||||
),
|
||||
i: int | NumberVar,
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
|
||||
) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
|
||||
i: int | NumberVar,
|
||||
) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
|
||||
i: int | NumberVar,
|
||||
) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[INNER_ARRAY_VAR, ...]],
|
||||
i: int | NumberVar,
|
||||
) -> ArrayVar[Tuple[INNER_ARRAY_VAR, ...]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[Dict[KEY_TYPE, VALUE_TYPE]],
|
||||
i: int | NumberVar,
|
||||
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ...
|
||||
|
||||
def __getitem__(
|
||||
self, i: slice | int | NumberVar
|
||||
) -> ArraySliceOperation | ImmutableVar:
|
||||
) -> ArrayVar[ARRAY_VAR_TYPE] | ImmutableVar:
|
||||
"""Get a slice of the array.
|
||||
|
||||
Args:
|
||||
@ -756,7 +866,7 @@ class ArrayVar(ImmutableVar):
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def range(cls, stop: int | NumberVar, /) -> RangeOperation: ...
|
||||
def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
@ -766,7 +876,7 @@ class ArrayVar(ImmutableVar):
|
||||
end: int | NumberVar,
|
||||
step: int | NumberVar = 1,
|
||||
/,
|
||||
) -> RangeOperation: ...
|
||||
) -> ArrayVar[List[int]]: ...
|
||||
|
||||
@classmethod
|
||||
def range(
|
||||
@ -774,7 +884,7 @@ class ArrayVar(ImmutableVar):
|
||||
first_endpoint: int | NumberVar,
|
||||
second_endpoint: int | NumberVar | None = None,
|
||||
step: int | NumberVar | None = None,
|
||||
) -> RangeOperation:
|
||||
) -> ArrayVar[List[int]]:
|
||||
"""Create a range of numbers.
|
||||
|
||||
Args:
|
||||
@ -794,7 +904,7 @@ class ArrayVar(ImmutableVar):
|
||||
|
||||
return RangeOperation(start, end, step or 1)
|
||||
|
||||
def contains(self, other: Any) -> ArrayContainsOperation:
|
||||
def contains(self, other: Any) -> BooleanVar:
|
||||
"""Check if the array contains an element.
|
||||
|
||||
Args:
|
||||
@ -806,12 +916,21 @@ class ArrayVar(ImmutableVar):
|
||||
return ArrayContainsOperation(self, other)
|
||||
|
||||
|
||||
LIST_ELEMENT = TypeVar("LIST_ELEMENT")
|
||||
|
||||
ARRAY_VAR_OF_LIST_ELEMENT = Union[
|
||||
ArrayVar[List[LIST_ELEMENT]],
|
||||
ArrayVar[Set[LIST_ELEMENT]],
|
||||
ArrayVar[Tuple[LIST_ELEMENT, ...]],
|
||||
]
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class LiteralArrayVar(LiteralVar, ArrayVar):
|
||||
class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
|
||||
"""Base class for immutable literal array vars."""
|
||||
|
||||
_var_value: Union[
|
||||
@ -819,9 +938,9 @@ class LiteralArrayVar(LiteralVar, ArrayVar):
|
||||
] = dataclasses.field(default_factory=list)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any],
|
||||
_var_type: type[list] | type[tuple] | type[set] | None = None,
|
||||
self: LiteralArrayVar[ARRAY_VAR_TYPE],
|
||||
_var_value: ARRAY_VAR_TYPE,
|
||||
_var_type: type[ARRAY_VAR_TYPE] | None = None,
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the array var.
|
||||
@ -834,11 +953,7 @@ class LiteralArrayVar(LiteralVar, ArrayVar):
|
||||
super(LiteralArrayVar, self).__init__(
|
||||
_var_name="",
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
_var_type=(
|
||||
List[unionize(*map(type, _var_value))]
|
||||
if _var_type is None
|
||||
else _var_type
|
||||
),
|
||||
_var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
|
||||
)
|
||||
object.__setattr__(self, "_var_value", _var_value)
|
||||
object.__delattr__(self, "_var_name")
|
||||
@ -1261,23 +1376,6 @@ class ArrayLengthOperation(ArrayToNumberOperation):
|
||||
return f"{str(self.a)}.length"
|
||||
|
||||
|
||||
def unionize(*args: Type) -> Type:
|
||||
"""Unionize the types.
|
||||
|
||||
Args:
|
||||
args: The types to unionize.
|
||||
|
||||
Returns:
|
||||
The unionized types.
|
||||
"""
|
||||
if not args:
|
||||
return Any
|
||||
first, *rest = args
|
||||
if not rest:
|
||||
return first
|
||||
return Union[first, unionize(*rest)]
|
||||
|
||||
|
||||
def is_tuple_type(t: GenericType) -> bool:
|
||||
"""Check if a type is a tuple type.
|
||||
|
||||
|
@ -1042,7 +1042,7 @@ def test_object_operations():
|
||||
|
||||
def test_type_chains():
|
||||
object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3})
|
||||
assert object_var._var_type is Dict[str, int]
|
||||
assert (object_var._key_type(), object_var._value_type()) == (str, int)
|
||||
assert (object_var.keys()._var_type, object_var.values()._var_type) == (
|
||||
List[str],
|
||||
List[int],
|
||||
@ -1061,6 +1061,31 @@ def test_type_chains():
|
||||
)
|
||||
|
||||
|
||||
def test_nested_dict():
|
||||
arr = LiteralArrayVar([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
|
||||
|
||||
assert (
|
||||
str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)'
|
||||
)
|
||||
|
||||
|
||||
def nested_base():
|
||||
class Boo(Base):
|
||||
foo: str
|
||||
bar: int
|
||||
|
||||
class Foo(Base):
|
||||
bar: Boo
|
||||
baz: int
|
||||
|
||||
parent_obj = LiteralVar.create(Foo(bar=Boo(foo="bar", bar=5), baz=5))
|
||||
|
||||
assert (
|
||||
str(parent_obj.bar.foo)
|
||||
== '({ ["bar"] : ({ ["foo"] : "bar", ["bar"] : 5 }), ["baz"] : 5 })["bar"]["foo"]'
|
||||
)
|
||||
|
||||
|
||||
def test_retrival():
|
||||
var_without_data = ImmutableVar.create("test")
|
||||
assert var_without_data is not None
|
||||
|
Loading…
Reference in New Issue
Block a user