Computed vars tuple and str indexing support (#1322)

This commit is contained in:
Elijah Ahianyo 2023-07-12 22:26:34 +00:00 committed by GitHub
parent 5cbf7da952
commit 7f0fc86816
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 20 deletions

View File

@ -16,6 +16,7 @@ from typing import (
List, List,
Optional, Optional,
Set, Set,
Tuple,
Type, Type,
Union, Union,
_GenericAlias, # type: ignore _GenericAlias, # type: ignore
@ -198,9 +199,9 @@ class Var(ABC):
Raises: Raises:
TypeError: If the var is not indexable. TypeError: If the var is not indexable.
""" """
# Indexing is only supported for lists, dicts, and dataframes. # Indexing is only supported for strings, lists, tuples, dicts, and dataframes.
if not ( if not (
types._issubclass(self.type_, Union[List, Dict]) types._issubclass(self.type_, Union[List, Dict, Tuple, str])
or types.is_dataframe(self.type_) or types.is_dataframe(self.type_)
): ):
if self.type_ == Any: if self.type_ == Any:
@ -218,9 +219,9 @@ class Var(ABC):
if isinstance(i, Var): if isinstance(i, Var):
i = BaseVar(name=i.name, type_=i.type_, state=i.state, is_local=True) i = BaseVar(name=i.name, type_=i.type_, state=i.state, is_local=True)
# Handle list indexing. # Handle list/tuple/str indexing.
if types._issubclass(self.type_, List): if types._issubclass(self.type_, Union[List, Tuple, str]):
# List indices must be ints, slices, or vars. # List/Tuple/String indices must be ints, slices, or vars.
if not isinstance(i, types.get_args(Union[int, slice, Var])): if not isinstance(i, types.get_args(Union[int, slice, Var])):
raise TypeError("Index must be an integer.") raise TypeError("Index must be an integer.")

View File

@ -1,5 +1,5 @@
import typing import typing
from typing import Dict, List from typing import Dict, List, Tuple
import cloudpickle import cloudpickle
import pytest import pytest
@ -272,31 +272,51 @@ def test_basic_operations(TestObj):
assert str(v([1, 2, 3]).length()) == "{[1, 2, 3].length}" assert str(v([1, 2, 3]).length()) == "{[1, 2, 3].length}"
def test_var_indexing_lists(): @pytest.mark.parametrize(
"""Test that we can index into list vars.""" "var",
lst = BaseVar(name="lst", type_=List[int]) [
BaseVar(name="list", type_=List[int]),
BaseVar(name="tuple", type_=Tuple[int, int]),
BaseVar(name="str", type_=str),
],
)
def test_var_indexing_lists(var):
"""Test that we can index into str, list or tuple vars.
Args:
var : The str, list or tuple base var.
"""
# Test basic indexing. # Test basic indexing.
assert str(lst[0]) == "{lst.at(0)}" assert str(var[0]) == f"{{{var.name}.at(0)}}"
assert str(lst[1]) == "{lst.at(1)}" assert str(var[1]) == f"{{{var.name}.at(1)}}"
# Test negative indexing. # Test negative indexing.
assert str(lst[-1]) == "{lst.at(-1)}" assert str(var[-1]) == f"{{{var.name}.at(-1)}}"
# Test non-integer indexing raises an error. # Test non-integer indexing raises an error.
with pytest.raises(TypeError): with pytest.raises(TypeError):
lst["a"] var["a"]
with pytest.raises(TypeError): with pytest.raises(TypeError):
lst[1.5] var[1.5]
def test_var_list_slicing(): @pytest.mark.parametrize(
"""Test that we can slice into list vars.""" "var",
lst = BaseVar(name="lst", type_=List[int]) [
BaseVar(name="lst", type_=List[int]),
BaseVar(name="tuple", type_=Tuple[int, int]),
BaseVar(name="str", type_=str),
],
)
def test_var_list_slicing(var):
"""Test that we can slice into str, list or tuple vars.
assert str(lst[:1]) == "{lst.slice(0, 1)}" Args:
assert str(lst[:1]) == "{lst.slice(0, 1)}" var : The str, list or tuple base var.
assert str(lst[:]) == "{lst.slice(0, undefined)}" """
assert str(var[:1]) == f"{{{var.name}.slice(0, 1)}}"
assert str(var[:1]) == f"{{{var.name}.slice(0, 1)}}"
assert str(var[:]) == f"{{{var.name}.slice(0, undefined)}}"
def test_dict_indexing(): def test_dict_indexing():