explicitly check getitem indices can have at most one ellipsis (#5087)

* explicitly check getitem indices can have at most one ellipsis

previous error with multiple `...`:
```
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index_type=<class 'ellipsis'> not supported
```

this pr:
```
if len(ellipsis_idx) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: an index can only have a single ellipsis ('...')
```

* oh we have that already

* test that

* test these
This commit is contained in:
chenyu
2024-06-21 12:33:18 -04:00
committed by GitHub
parent f1e758bacb
commit 36b4a492a1
2 changed files with 6 additions and 6 deletions

View File

@@ -1039,10 +1039,10 @@ class TestOps(unittest.TestCase):
def test_slice_errors(self):
a = Tensor.ones(4, 3)
b = Tensor(2)
with self.assertRaises(IndexError): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
with self.assertRaises(IndexError): a[1, 3] # IndexError: (out of bounds).
with self.assertRaises(IndexError): a[1, -4]
with self.assertRaises(IndexError): a[..., ...] # IndexError: only single ellipsis
with self.assertRaisesRegex(IndexError, "too many"): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, 3] # IndexError: (out of bounds).
with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, -4]
with self.assertRaisesRegex(IndexError, "single ellipsis"): a[..., ...] # IndexError: only single ellipsis
with self.assertRaises(ValueError): a[::0, 1] # no 0 strides
with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor

View File

@@ -914,7 +914,7 @@ class Tensor:
ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
# use Dict[type, List[dimension]] to track elements in indices
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
@@ -925,9 +925,9 @@ class Tensor:
indices_filtered = [i for i in indices if i is not None]
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
for index_type in type_dim:
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
# 2. basic indexing, uses only movement ops (no copy)