add type hinting to existing types (#3729)

* add type hinting to existing types

* dang it darglint

* i cannot
This commit is contained in:
Khaleel Al-Adhami 2024-07-31 12:01:17 -07:00 committed by GitHub
parent 129adc941a
commit ad14f38329
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 415 additions and 65 deletions

View File

@ -10,9 +10,15 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
)
@ -42,13 +48,15 @@ if TYPE_CHECKING:
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
VAR_TYPE = TypeVar("VAR_TYPE")
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableVar(Var):
class ImmutableVar(Var, Generic[VAR_TYPE]):
"""Base class for immutable vars."""
# The name of the var.
@ -405,6 +413,8 @@ class ImmutableVar(Var):
return self.to(ArrayVar, var_type)
if issubclass(fixed_type, str):
return self.to(StringVar)
if issubclass(fixed_type, Base):
return self.to(ObjectVar, var_type)
return self
@ -531,3 +541,43 @@ def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P
return wrapper
return decorator
def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
first, *rest = args
if not rest:
return first
return Union[first, unionize(*rest)]
def figure_out_type(value: Any) -> Type:
"""Figure out the type of the value.
Args:
value: The value to figure out the type of.
Returns:
The type of the value.
"""
if isinstance(value, list):
return List[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, set):
return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict):
return Dict[
unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())),
]
return type(value)

View File

@ -11,7 +11,7 @@ from reflex.experimental.vars.base import ImmutableVar, LiteralVar
from reflex.vars import ImmutableVarData, Var, VarData
class FunctionVar(ImmutableVar):
class FunctionVar(ImmutableVar[Callable]):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:

View File

@ -15,7 +15,7 @@ from reflex.experimental.vars.base import (
from reflex.vars import ImmutableVarData, Var, VarData
class NumberVar(ImmutableVar):
class NumberVar(ImmutableVar[Union[int, float]]):
"""Base class for immutable number vars."""
def __add__(self, other: number_types | boolean_types) -> NumberAddOperation:
@ -693,7 +693,7 @@ class NumberTruncOperation(UnaryNumberOperation):
return f"Math.trunc({str(value)})"
class BooleanVar(ImmutableVar):
class BooleanVar(ImmutableVar[bool]):
"""Base class for immutable boolean vars."""
def __and__(self, other: bool) -> BooleanAndOperation:

View File

@ -6,23 +6,69 @@ import dataclasses
import sys
import typing
from functools import cached_property
from typing import Any, Dict, List, Tuple, Type, Union
from inspect import isclass
from typing import (
Any,
Dict,
List,
NoReturn,
Tuple,
Type,
TypeVar,
Union,
get_args,
overload,
)
from reflex.experimental.vars.base import ImmutableVar, LiteralVar
from reflex.experimental.vars.sequence import ArrayVar, unionize
from typing_extensions import get_origin
from reflex.experimental.vars.base import (
ImmutableVar,
LiteralVar,
figure_out_type,
)
from reflex.experimental.vars.number import NumberVar
from reflex.experimental.vars.sequence import ArrayVar, StringVar
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type
from reflex.vars import ImmutableVarData, Var, VarData
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
class ObjectVar(ImmutableVar):
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(ImmutableVar[OBJECT_TYPE]):
"""Base class for immutable object vars."""
@overload
def _key_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> KEY_TYPE: ...
@overload
def _key_type(self) -> Type: ...
def _key_type(self) -> Type:
"""Get the type of the keys of the object.
Returns:
The type of the keys of the object.
"""
return ImmutableVar
fixed_type = (
self._var_type if isclass(self._var_type) else get_origin(self._var_type)
)
args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
return args[0] if args else Any
@overload
def _value_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> VALUE_TYPE: ...
@overload
def _value_type(self) -> Type: ...
def _value_type(self) -> Type:
"""Get the type of the values of the object.
@ -30,9 +76,21 @@ class ObjectVar(ImmutableVar):
Returns:
The type of the values of the object.
"""
return ImmutableVar
fixed_type = (
self._var_type if isclass(self._var_type) else get_origin(self._var_type)
)
args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
return args[1] if args else Any
def keys(self) -> ObjectKeysOperation:
@overload
def keys(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
) -> ArrayVar[List[KEY_TYPE]]: ...
@overload
def keys(self) -> ArrayVar: ...
def keys(self) -> ArrayVar:
"""Get the keys of the object.
Returns:
@ -40,7 +98,15 @@ class ObjectVar(ImmutableVar):
"""
return ObjectKeysOperation(self)
def values(self) -> ObjectValuesOperation:
@overload
def values(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload
def values(self) -> ArrayVar: ...
def values(self) -> ArrayVar:
"""Get the values of the object.
Returns:
@ -48,7 +114,15 @@ class ObjectVar(ImmutableVar):
"""
return ObjectValuesOperation(self)
def entries(self) -> ObjectEntriesOperation:
@overload
def entries(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[KEY_TYPE, VALUE_TYPE]]]: ...
@overload
def entries(self) -> ArrayVar: ...
def entries(self) -> ArrayVar:
"""Get the entries of the object.
Returns:
@ -67,6 +141,53 @@ class ObjectVar(ImmutableVar):
"""
return ObjectMergeOperation(self, other)
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
key: Var | Any,
) -> ImmutableVar: ...
@overload
def __getitem__(
self: (
ObjectVar[Dict[KEY_TYPE, int]]
| ObjectVar[Dict[KEY_TYPE, float]]
| ObjectVar[Dict[KEY_TYPE, int | float]]
),
key: Var | Any,
) -> NumberVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, str]],
key: Var | Any,
) -> StringVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getitem__(self, key: Var | Any) -> ImmutableVar:
"""Get an item from the object.
@ -78,16 +199,78 @@ class ObjectVar(ImmutableVar):
"""
return ObjectItemOperation(self, key).guess_type()
def __getattr__(self, name) -> ObjectItemOperation:
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
name: str,
) -> ImmutableVar: ...
@overload
def __getattr__(
self: (
ObjectVar[Dict[KEY_TYPE, int]]
| ObjectVar[Dict[KEY_TYPE, float]]
| ObjectVar[Dict[KEY_TYPE, int | float]]
),
name: str,
) -> NumberVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, str]],
name: str,
) -> StringVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getattr__(self, name) -> ImmutableVar:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Raises:
VarAttributeError: The State var has no such attribute or may have been annotated wrongly.
Returns:
The attribute of the var.
"""
return ObjectItemOperation(self, name)
fixed_type = (
self._var_type if isclass(self._var_type) else get_origin(self._var_type)
)
if not issubclass(fixed_type, dict):
attribute_type = get_attribute_access_type(self._var_type, name)
if attribute_type is None:
raise VarAttributeError(
f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated "
f"wrongly."
)
return ObjectItemOperation(self, name, attribute_type).guess_type()
else:
return ObjectItemOperation(self, name).guess_type()
@dataclasses.dataclass(
@ -95,7 +278,7 @@ class ObjectVar(ImmutableVar):
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralObjectVar(LiteralVar, ObjectVar):
class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
"""Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
@ -103,9 +286,9 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
)
def __init__(
self,
_var_value: dict[Var | Any, Var | Any],
_var_type: Type | None = None,
self: LiteralObjectVar[OBJECT_TYPE],
_var_value: OBJECT_TYPE,
_var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None,
):
"""Initialize the object var.
@ -117,14 +300,7 @@ class LiteralObjectVar(LiteralVar, ObjectVar):
"""
super(LiteralObjectVar, self).__init__(
_var_name="",
_var_type=(
Dict[
unionize(*map(type, _var_value.keys())),
unionize(*map(type, _var_value.values())),
]
if _var_type is None
else _var_type
),
_var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
@ -489,6 +665,7 @@ class ObjectItemOperation(ImmutableVar):
self,
value: ObjectVar,
key: Var | Any,
_var_type: GenericType | None = None,
_var_data: VarData | None = None,
):
"""Initialize the object item operation.
@ -500,7 +677,7 @@ class ObjectItemOperation(ImmutableVar):
"""
super(ObjectItemOperation, self).__init__(
_var_name="",
_var_type=value._value_type(),
_var_type=value._value_type() if _var_type is None else _var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "value", value)

View File

@ -10,7 +10,18 @@ import re
import sys
import typing
from functools import cached_property
from typing import Any, List, Set, Tuple, Type, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import get_origin
@ -19,6 +30,8 @@ from reflex.constants.base import REFLEX_VAR_OPENING_TAG
from reflex.experimental.vars.base import (
ImmutableVar,
LiteralVar,
figure_out_type,
unionize,
)
from reflex.experimental.vars.number import (
BooleanVar,
@ -29,8 +42,11 @@ from reflex.experimental.vars.number import (
from reflex.utils.types import GenericType
from reflex.vars import ImmutableVarData, Var, VarData, _global_vars
if TYPE_CHECKING:
from .object import ObjectVar
class StringVar(ImmutableVar):
class StringVar(ImmutableVar[str]):
"""Base class for immutable string vars."""
def __add__(self, other: StringVar | str) -> ConcatVarOperation:
@ -699,7 +715,17 @@ class ConcatVarOperation(StringVar):
pass
class ArrayVar(ImmutableVar):
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
OTHER_TUPLE = TypeVar("OTHER_TUPLE")
INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR")
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]):
"""Base class for immutable array vars."""
from reflex.experimental.vars.sequence import StringVar
@ -717,7 +743,7 @@ class ArrayVar(ImmutableVar):
return ArrayJoinOperation(self, sep)
def reverse(self) -> ArrayReverseOperation:
def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]:
"""Reverse the array.
Returns:
@ -726,14 +752,98 @@ class ArrayVar(ImmutableVar):
return ArrayReverseOperation(self)
@overload
def __getitem__(self, i: slice) -> ArraySliceOperation: ...
def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ...
@overload
def __getitem__(
self: (
ArrayVar[Tuple[int, OTHER_TUPLE]]
| ArrayVar[Tuple[float, OTHER_TUPLE]]
| ArrayVar[Tuple[int | float, OTHER_TUPLE]]
),
i: Literal[0, -2],
) -> NumberVar: ...
@overload
def __getitem__(
self: (
ArrayVar[Tuple[OTHER_TUPLE, int]]
| ArrayVar[Tuple[OTHER_TUPLE, float]]
| ArrayVar[Tuple[OTHER_TUPLE, int | float]]
),
i: Literal[1, -1],
) -> NumberVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
) -> StringVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
) -> StringVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
) -> BooleanVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
) -> BooleanVar: ...
@overload
def __getitem__(
self: (
ARRAY_VAR_OF_LIST_ELEMENT[int]
| ARRAY_VAR_OF_LIST_ELEMENT[float]
| ARRAY_VAR_OF_LIST_ELEMENT[int | float]
),
i: int | NumberVar,
) -> NumberVar: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
) -> StringVar: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
) -> BooleanVar: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
i: int | NumberVar,
) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
i: int | NumberVar,
) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[INNER_ARRAY_VAR, ...]],
i: int | NumberVar,
) -> ArrayVar[Tuple[INNER_ARRAY_VAR, ...]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Dict[KEY_TYPE, VALUE_TYPE]],
i: int | NumberVar,
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ...
def __getitem__(
self, i: slice | int | NumberVar
) -> ArraySliceOperation | ImmutableVar:
) -> ArrayVar[ARRAY_VAR_TYPE] | ImmutableVar:
"""Get a slice of the array.
Args:
@ -756,7 +866,7 @@ class ArrayVar(ImmutableVar):
@overload
@classmethod
def range(cls, stop: int | NumberVar, /) -> RangeOperation: ...
def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ...
@overload
@classmethod
@ -766,7 +876,7 @@ class ArrayVar(ImmutableVar):
end: int | NumberVar,
step: int | NumberVar = 1,
/,
) -> RangeOperation: ...
) -> ArrayVar[List[int]]: ...
@classmethod
def range(
@ -774,7 +884,7 @@ class ArrayVar(ImmutableVar):
first_endpoint: int | NumberVar,
second_endpoint: int | NumberVar | None = None,
step: int | NumberVar | None = None,
) -> RangeOperation:
) -> ArrayVar[List[int]]:
"""Create a range of numbers.
Args:
@ -794,7 +904,7 @@ class ArrayVar(ImmutableVar):
return RangeOperation(start, end, step or 1)
def contains(self, other: Any) -> ArrayContainsOperation:
def contains(self, other: Any) -> BooleanVar:
"""Check if the array contains an element.
Args:
@ -806,12 +916,21 @@ class ArrayVar(ImmutableVar):
return ArrayContainsOperation(self, other)
LIST_ELEMENT = TypeVar("LIST_ELEMENT")
ARRAY_VAR_OF_LIST_ELEMENT = Union[
ArrayVar[List[LIST_ELEMENT]],
ArrayVar[Set[LIST_ELEMENT]],
ArrayVar[Tuple[LIST_ELEMENT, ...]],
]
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralArrayVar(LiteralVar, ArrayVar):
class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
"""Base class for immutable literal array vars."""
_var_value: Union[
@ -819,9 +938,9 @@ class LiteralArrayVar(LiteralVar, ArrayVar):
] = dataclasses.field(default_factory=list)
def __init__(
self,
_var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any],
_var_type: type[list] | type[tuple] | type[set] | None = None,
self: LiteralArrayVar[ARRAY_VAR_TYPE],
_var_value: ARRAY_VAR_TYPE,
_var_type: type[ARRAY_VAR_TYPE] | None = None,
_var_data: VarData | None = None,
):
"""Initialize the array var.
@ -834,11 +953,7 @@ class LiteralArrayVar(LiteralVar, ArrayVar):
super(LiteralArrayVar, self).__init__(
_var_name="",
_var_data=ImmutableVarData.merge(_var_data),
_var_type=(
List[unionize(*map(type, _var_value))]
if _var_type is None
else _var_type
),
_var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
)
object.__setattr__(self, "_var_value", _var_value)
object.__delattr__(self, "_var_name")
@ -1261,23 +1376,6 @@ class ArrayLengthOperation(ArrayToNumberOperation):
return f"{str(self.a)}.length"
def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
first, *rest = args
if not rest:
return first
return Union[first, unionize(*rest)]
def is_tuple_type(t: GenericType) -> bool:
"""Check if a type is a tuple type.

View File

@ -1042,7 +1042,7 @@ def test_object_operations():
def test_type_chains():
object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3})
assert object_var._var_type is Dict[str, int]
assert (object_var._key_type(), object_var._value_type()) == (str, int)
assert (object_var.keys()._var_type, object_var.values()._var_type) == (
List[str],
List[int],
@ -1061,6 +1061,31 @@ def test_type_chains():
)
def test_nested_dict():
arr = LiteralArrayVar([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
assert (
str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)'
)
def nested_base():
class Boo(Base):
foo: str
bar: int
class Foo(Base):
bar: Boo
baz: int
parent_obj = LiteralVar.create(Foo(bar=Boo(foo="bar", bar=5), baz=5))
assert (
str(parent_obj.bar.foo)
== '({ ["bar"] : ({ ["foo"] : "bar", ["bar"] : 5 }), ["baz"] : 5 })["bar"]["foo"]'
)
def test_retrival():
var_without_data = ImmutableVar.create("test")
assert var_without_data is not None