mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
explicitly check getitem indices can have at most one ellipsis (#5087)
* explicitly check getitem indices can have at most one ellipsis
previous error with multiple `...`:
```
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index_type=<class 'ellipsis'> not supported
```
this pr:
```
if len(ellipsis_idx) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: an index can only have a single ellipsis ('...')
```
* oh we have that already
* test that
* test these
This commit is contained in:
@@ -1039,10 +1039,10 @@ class TestOps(unittest.TestCase):
|
||||
def test_slice_errors(self):
|
||||
a = Tensor.ones(4, 3)
|
||||
b = Tensor(2)
|
||||
with self.assertRaises(IndexError): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
|
||||
with self.assertRaises(IndexError): a[1, 3] # IndexError: (out of bounds).
|
||||
with self.assertRaises(IndexError): a[1, -4]
|
||||
with self.assertRaises(IndexError): a[..., ...] # IndexError: only single ellipsis
|
||||
with self.assertRaisesRegex(IndexError, "too many"): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
|
||||
with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, 3] # IndexError: (out of bounds).
|
||||
with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, -4]
|
||||
with self.assertRaisesRegex(IndexError, "single ellipsis"): a[..., ...] # IndexError: only single ellipsis
|
||||
with self.assertRaises(ValueError): a[::0, 1] # no 0 strides
|
||||
with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor
|
||||
|
||||
|
||||
@@ -914,7 +914,7 @@ class Tensor:
|
||||
ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
|
||||
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
|
||||
num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
|
||||
indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices)
|
||||
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
|
||||
|
||||
# use Dict[type, List[dimension]] to track elements in indices
|
||||
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
|
||||
@@ -925,9 +925,9 @@ class Tensor:
|
||||
indices_filtered = [i for i in indices if i is not None]
|
||||
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
|
||||
|
||||
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
|
||||
for index_type in type_dim:
|
||||
if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
|
||||
if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
|
||||
if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
|
||||
|
||||
# 2. basic indexing, uses only movement ops (no copy)
|
||||
|
||||
Reference in New Issue
Block a user