mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Initial ellipsis support when slicing Tensors (#843)
* Initial ellipsis support when slicing Tensors * Better comments in ellipsis slicing * Formatting
This commit is contained in:
committed by
GitHub
parent
70f12fdb57
commit
78460034ff
@@ -262,7 +262,15 @@ class Tensor:
|
||||
val = list(val) if isinstance(val, tuple) else [val]
|
||||
if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
|
||||
raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
|
||||
orig_slices = list(val) + [slice(None)] * (len(self.shape) - num_slices)
|
||||
orig_slices = list(val)
|
||||
ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
|
||||
if len(ellipses_found) > 0:
|
||||
if len(ellipses_found) != 1:
|
||||
raise IndexError("an index can only have a single ellipsis ('...')")
|
||||
ellipsis_idx = ellipses_found[0]
|
||||
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
|
||||
else:
|
||||
orig_slices += [slice(None)] * (len(self.shape) - num_slices)
|
||||
valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
|
||||
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
|
||||
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
|
||||
|
||||
Reference in New Issue
Block a user