Computed vars tuple and str indexing support (#1322)

This commit is contained in:
Elijah Ahianyo 2023-07-12 22:26:34 +00:00 committed by GitHub
parent 5cbf7da952
commit 7f0fc86816
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 20 deletions

View File

@ -16,6 +16,7 @@ from typing import (
List,
Optional,
Set,
Tuple,
Type,
Union,
_GenericAlias, # type: ignore
@ -198,9 +199,9 @@ class Var(ABC):
Raises:
TypeError: If the var is not indexable.
"""
# Indexing is only supported for lists, dicts, and dataframes.
# Indexing is only supported for strings, lists, tuples, dicts, and dataframes.
if not (
types._issubclass(self.type_, Union[List, Dict])
types._issubclass(self.type_, Union[List, Dict, Tuple, str])
or types.is_dataframe(self.type_)
):
if self.type_ == Any:
@ -218,9 +219,9 @@ class Var(ABC):
if isinstance(i, Var):
i = BaseVar(name=i.name, type_=i.type_, state=i.state, is_local=True)
# Handle list indexing.
if types._issubclass(self.type_, List):
# List indices must be ints, slices, or vars.
# Handle list/tuple/str indexing.
if types._issubclass(self.type_, Union[List, Tuple, str]):
# List/Tuple/String indices must be ints, slices, or vars.
if not isinstance(i, types.get_args(Union[int, slice, Var])):
raise TypeError("Index must be an integer.")

View File

@ -1,5 +1,5 @@
import typing
from typing import Dict, List
from typing import Dict, List, Tuple
import cloudpickle
import pytest
@ -272,31 +272,51 @@ def test_basic_operations(TestObj):
assert str(v([1, 2, 3]).length()) == "{[1, 2, 3].length}"
def test_var_indexing_lists():
"""Test that we can index into list vars."""
lst = BaseVar(name="lst", type_=List[int])
@pytest.mark.parametrize(
"var",
[
BaseVar(name="list", type_=List[int]),
BaseVar(name="tuple", type_=Tuple[int, int]),
BaseVar(name="str", type_=str),
],
)
def test_var_indexing_lists(var):
"""Test that we can index into str, list or tuple vars.
Args:
var : The str, list or tuple base var.
"""
# Test basic indexing.
assert str(lst[0]) == "{lst.at(0)}"
assert str(lst[1]) == "{lst.at(1)}"
assert str(var[0]) == f"{{{var.name}.at(0)}}"
assert str(var[1]) == f"{{{var.name}.at(1)}}"
# Test negative indexing.
assert str(lst[-1]) == "{lst.at(-1)}"
assert str(var[-1]) == f"{{{var.name}.at(-1)}}"
# Test non-integer indexing raises an error.
with pytest.raises(TypeError):
lst["a"]
var["a"]
with pytest.raises(TypeError):
lst[1.5]
var[1.5]
def test_var_list_slicing():
"""Test that we can slice into list vars."""
lst = BaseVar(name="lst", type_=List[int])
@pytest.mark.parametrize(
"var",
[
BaseVar(name="lst", type_=List[int]),
BaseVar(name="tuple", type_=Tuple[int, int]),
BaseVar(name="str", type_=str),
],
)
def test_var_list_slicing(var):
"""Test that we can slice into str, list or tuple vars.
assert str(lst[:1]) == "{lst.slice(0, 1)}"
assert str(lst[:1]) == "{lst.slice(0, 1)}"
assert str(lst[:]) == "{lst.slice(0, undefined)}"
Args:
var : The str, list or tuple base var.
"""
assert str(var[:1]) == f"{{{var.name}.slice(0, 1)}}"
assert str(var[:1]) == f"{{{var.name}.slice(0, 1)}}"
assert str(var[:]) == f"{{{var.name}.slice(0, undefined)}}"
def test_dict_indexing():