diff --git a/reflex/vars.py b/reflex/vars.py index 35f9925d0..5898fca1a 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -206,7 +206,8 @@ class Var(ABC): ): if self.type_ == Any: raise TypeError( - f"Could not index into var of type Any. (If you are trying to index into a state var, add the correct type annotation to the var.)" + f"Could not index into var of type Any. (If you are trying to index into a state var, " + f"add the correct type annotation to the var.)" ) raise TypeError( f"Var {self.name} of type {self.type_} does not support indexing." @@ -222,8 +223,12 @@ class Var(ABC): # Handle list/tuple/str indexing. if types._issubclass(self.type_, Union[List, Tuple, str]): # List/Tuple/String indices must be ints, slices, or vars. - if not isinstance(i, types.get_args(Union[int, slice, Var])): - raise TypeError("Index must be an integer.") + if ( + not isinstance(i, types.get_args(Union[int, slice, Var])) + or isinstance(i, Var) + and not i.type_ == int + ): + raise TypeError("Index must be an integer or an integer var.") # Handle slices first. if isinstance(i, slice): @@ -253,6 +258,17 @@ class Var(ABC): ) # Dictionary / dataframe indexing. + # Tuples are currently not supported as indexes. + if ( + (types._issubclass(self.type_, Dict) or types.is_dataframe(self.type_)) + and not isinstance(i, types.get_args(Union[int, str, float, Var])) + ) or ( + isinstance(i, Var) + and not types._issubclass(i.type_, types.get_args(Union[int, str, float])) + ): + raise TypeError( + "Index must be one of the following types: int, str, int or str Var" + ) # Get the type of the indexed var. if isinstance(i, str): i = format.wrap(i, '"') diff --git a/tests/test_var.py b/tests/test_var.py index 752c460ed..9dd0784ef 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -1,8 +1,9 @@ import typing -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple import cloudpickle import pytest +from pandas import DataFrame from reflex.base import Base from reflex.state import State @@ -293,11 +294,54 @@ def test_var_indexing_lists(var): # Test negative indexing. assert str(var[-1]) == f"{{{var.name}.at(-1)}}" - # Test non-integer indexing raises an error. + +@pytest.mark.parametrize( + "var, index", + [ + (BaseVar(name="lst", type_=List[int]), [1, 2]), + (BaseVar(name="lst", type_=List[int]), {"name": "dict"}), + (BaseVar(name="lst", type_=List[int]), {"set"}), + ( + BaseVar(name="lst", type_=List[int]), + ( + 1, + 2, + ), + ), + (BaseVar(name="lst", type_=List[int]), 1.5), + (BaseVar(name="lst", type_=List[int]), "str"), + (BaseVar(name="lst", type_=List[int]), BaseVar(name="string_var", type_=str)), + (BaseVar(name="lst", type_=List[int]), BaseVar(name="float_var", type_=float)), + ( + BaseVar(name="lst", type_=List[int]), + BaseVar(name="list_var", type_=List[int]), + ), + (BaseVar(name="lst", type_=List[int]), BaseVar(name="set_var", type_=Set[str])), + ( + BaseVar(name="lst", type_=List[int]), + BaseVar(name="dict_var", type_=Dict[str, str]), + ), + (BaseVar(name="str", type_=str), [1, 2]), + (BaseVar(name="lst", type_=str), {"name": "dict"}), + (BaseVar(name="lst", type_=str), {"set"}), + (BaseVar(name="lst", type_=str), BaseVar(name="string_var", type_=str)), + (BaseVar(name="lst", type_=str), BaseVar(name="float_var", type_=float)), + (BaseVar(name="str", type_=Tuple[str]), [1, 2]), + (BaseVar(name="lst", type_=Tuple[str]), {"name": "dict"}), + (BaseVar(name="lst", type_=Tuple[str]), {"set"}), + (BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="string_var", type_=str)), + (BaseVar(name="lst", type_=Tuple[str]), BaseVar(name="float_var", type_=float)), + ], +) +def test_var_unsupported_indexing_lists(var, index): + """Test unsupported indexing throws a type error. + + Args: + var: The base var. + index: The base var index. + """ with pytest.raises(TypeError): - var["a"] - with pytest.raises(TypeError): - var[1.5] + var[index] @pytest.mark.parametrize( @@ -328,6 +372,84 @@ def test_dict_indexing(): assert str(dct["asdf"]) == '{dct["asdf"]}' +@pytest.mark.parametrize( + "var, index", + [ + ( + BaseVar(name="dict", type_=Dict[str, str]), + [1, 2], + ), + ( + BaseVar(name="dict", type_=Dict[str, str]), + {"name": "dict"}, + ), + ( + BaseVar(name="dict", type_=Dict[str, str]), + {"set"}, + ), + ( + BaseVar(name="dict", type_=Dict[str, str]), + ( + 1, + 2, + ), + ), + ( + BaseVar(name="lst", type_=Dict[str, str]), + BaseVar(name="list_var", type_=List[int]), + ), + ( + BaseVar(name="lst", type_=Dict[str, str]), + BaseVar(name="set_var", type_=Set[str]), + ), + ( + BaseVar(name="lst", type_=Dict[str, str]), + BaseVar(name="dict_var", type_=Dict[str, str]), + ), + ( + BaseVar(name="df", type_=DataFrame), + [1, 2], + ), + ( + BaseVar(name="df", type_=DataFrame), + {"name": "dict"}, + ), + ( + BaseVar(name="df", type_=DataFrame), + {"set"}, + ), + ( + BaseVar(name="df", type_=DataFrame), + ( + 1, + 2, + ), + ), + ( + BaseVar(name="df", type_=DataFrame), + BaseVar(name="list_var", type_=List[int]), + ), + ( + BaseVar(name="df", type_=DataFrame), + BaseVar(name="set_var", type_=Set[str]), + ), + ( + BaseVar(name="df", type_=DataFrame), + BaseVar(name="dict_var", type_=Dict[str, str]), + ), + ], +) +def test_var_unsupported_indexing_dicts(var, index): + """Test unsupported indexing throws a type error. + + Args: + var: The base var. + index: The base var index. + """ + with pytest.raises(TypeError): + var[index] + + @pytest.mark.parametrize( "fixture,full_name", [