Initial ellipsis support when slicing Tensors (#843)

* Initial ellipsis support when slicing Tensors

* Better comments in ellipsis slicing

* Formatting
This commit is contained in:
Filip Dimitrovski
2023-06-05 16:52:49 +02:00
committed by GitHub
parent 70f12fdb57
commit 78460034ff
2 changed files with 16 additions and 1 deletions

View File

@@ -262,7 +262,15 @@ class Tensor:
val = list(val) if isinstance(val, tuple) else [val]
if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
orig_slices = list(val) + [slice(None)] * (len(self.shape) - num_slices)
orig_slices = list(val)
ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
if len(ellipses_found) > 0:
if len(ellipses_found) != 1:
raise IndexError("an index can only have a single ellipsis ('...')")
ellipsis_idx = ellipses_found[0]
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
else:
orig_slices += [slice(None)] * (len(self.shape) - num_slices)
valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())