Strict type checking for indexing with vars ()

This commit is contained in:
Elijah Ahianyo 2023-07-13 22:46:15 +00:00 committed by GitHub
parent 2c97c1e7ca
commit 40953d05ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 146 additions and 8 deletions

View File

@ -206,7 +206,8 @@ class Var(ABC):
):
if self.type_ == Any:
raise TypeError(
f"Could not index into var of type Any. (If you are trying to index into a state var, add the correct type annotation to the var.)"
f"Could not index into var of type Any. (If you are trying to index into a state var, "
f"add the correct type annotation to the var.)"
)
raise TypeError(
f"Var {self.name} of type {self.type_} does not support indexing."
@ -222,8 +223,12 @@ class Var(ABC):
# Handle list/tuple/str indexing.
if types._issubclass(self.type_, Union[List, Tuple, str]):
# List/Tuple/String indices must be ints, slices, or vars.
if not isinstance(i, types.get_args(Union[int, slice, Var])):
raise TypeError("Index must be an integer.")
if (
not isinstance(i, types.get_args(Union[int, slice, Var]))
or isinstance(i, Var)
and not i.type_ == int
):
raise TypeError("Index must be an integer or an integer var.")
# Handle slices first.
if isinstance(i, slice):
@ -253,6 +258,17 @@ class Var(ABC):
)
# Dictionary / dataframe indexing.
# Tuples are currently not supported as indexes.
if (
(types._issubclass(self.type_, Dict) or types.is_dataframe(self.type_))
and not isinstance(i, types.get_args(Union[int, str, float, Var]))
) or (
isinstance(i, Var)
and not types._issubclass(i.type_, types.get_args(Union[int, str, float]))
):
raise TypeError(
"Index must be one of the following types: int, str, int or str Var"
)
# Get the type of the indexed var.
if isinstance(i, str):
i = format.wrap(i, '"')

View File

@ -1,8 +1,9 @@
import typing
from typing import Dict, List, Tuple
from typing import Dict, List, Set, Tuple
import cloudpickle
import pytest
from pandas import DataFrame
from reflex.base import Base
from reflex.state import State
@ -293,11 +294,54 @@ def test_var_indexing_lists(var):
# Test negative indexing.
assert str(var[-1]) == f"{{{var.name}.at(-1)}}"
# Test non-integer indexing raises an error.
@pytest.mark.parametrize(
"var, index",
[
(BaseVar(name="lst", type_=List[int]), [1, 2]),
(BaseVar(name="lst", type_=List[int]), {"name": "dict"}),
(BaseVar(name="lst", type_=List[int]), {"set"}),
(
BaseVar(name="lst", type_=List[int]),
(
1,
2,
),
),
(BaseVar(name="lst", type_=List[int]), 1.5),
(BaseVar(name="lst", type_=List[int]), "str"),
(BaseVar(name="lst", type_=List[int]), BaseVar(name="string_var", type_=str)),
(BaseVar(name="lst", type_=List[int]), BaseVar(name="float_var", type_=float)),
(
BaseVar(name="lst", type_=List[int]),
BaseVar(name="list_var", type_=List[int]),
),
(BaseVar(name="lst", type_=List[int]), BaseVar(name="set_var", type_=Set[str])),
(
BaseVar(name="lst", type_=List[int]),
BaseVar(name="dict_var", type_=Dict[str, str]),
),
(BaseVar(name="str", type_=str), [1, 2]),
(BaseVar(name="lst", type_=str), {"name": "dict"}),
(BaseVar(name="lst", type_=str), {"set"}),
(BaseVar(name="lst", type_=str), BaseVar(name="string_var", type_=str)),
(BaseVar(name="lst", type_=str), BaseVar(name="float_var", type_=float)),
(BaseVar(name="str", type_=Tuple[str]), [1, 2]),
(BaseVar(name="lst", type_=Tuple[str]), {"name": "dict"}),
(BaseVar(name="lst", type_=Tuple[str]), {"set"}),
(BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="string_var", type_=str)),
(BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="float_var", type_=float)),
],
)
def test_var_unsupported_indexing_lists(var, index):
"""Test unsupported indexing throws a type error.
Args:
var: The base var.
index: The base var index.
"""
with pytest.raises(TypeError):
var["a"]
with pytest.raises(TypeError):
var[1.5]
var[index]
@pytest.mark.parametrize(
@ -328,6 +372,84 @@ def test_dict_indexing():
assert str(dct["asdf"]) == '{dct["asdf"]}'
@pytest.mark.parametrize(
"var, index",
[
(
BaseVar(name="dict", type_=Dict[str, str]),
[1, 2],
),
(
BaseVar(name="dict", type_=Dict[str, str]),
{"name": "dict"},
),
(
BaseVar(name="dict", type_=Dict[str, str]),
{"set"},
),
(
BaseVar(name="dict", type_=Dict[str, str]),
(
1,
2,
),
),
(
BaseVar(name="lst", type_=Dict[str, str]),
BaseVar(name="list_var", type_=List[int]),
),
(
BaseVar(name="lst", type_=Dict[str, str]),
BaseVar(name="set_var", type_=Set[str]),
),
(
BaseVar(name="lst", type_=Dict[str, str]),
BaseVar(name="dict_var", type_=Dict[str, str]),
),
(
BaseVar(name="df", type_=DataFrame),
[1, 2],
),
(
BaseVar(name="df", type_=DataFrame),
{"name": "dict"},
),
(
BaseVar(name="df", type_=DataFrame),
{"set"},
),
(
BaseVar(name="df", type_=DataFrame),
(
1,
2,
),
),
(
BaseVar(name="df", type_=DataFrame),
BaseVar(name="list_var", type_=List[int]),
),
(
BaseVar(name="df", type_=DataFrame),
BaseVar(name="set_var", type_=Set[str]),
),
(
BaseVar(name="df", type_=DataFrame),
BaseVar(name="dict_var", type_=Dict[str, str]),
),
],
)
def test_var_unsupported_indexing_dicts(var, index):
"""Test unsupported indexing throws a type error.
Args:
var: The base var.
index: The base var index.
"""
with pytest.raises(TypeError):
var[index]
@pytest.mark.parametrize(
"fixture,full_name",
[