diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 762b64067e..02bed59589 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -335,11 +335,9 @@ class Tensor: # use Dict[type, List[dimension]] to track elements in indices type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list) - # record None for dimension injection later + # record None for dimension injection later and filter None and record rest of indices type_dim[None] = [dim for dim, i in enumerate(indices) if i is None] - - # filter None and record rest of indices - indices_filtered = tuple(v for v in indices if v is not None) + indices_filtered = [v for v in indices if v is not None] for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim) # validation! raise Errors @@ -348,34 +346,35 @@ class Tensor: if float in type_dim: raise IndexError("float type is not valid index") if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0') if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}") - for dim in type_dim[int]: - if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]: - raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}") - - # normalize! indices -> start, stop, strides - # TODO: this line is completely unreadable - start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type] # noqa: E501 # 2. basic indexing (no copy) - # apply slices and flip where strides are negative - new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides)) - sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0]) - new_shape = list(sliced_tensor.shape) + # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...] + # turn indices in indices_filtered to Tuple[shrink_arg, strides] + for dim in type_dim[int]: + if (i := indices_filtered[dim]) >= (sh := self.shape[dim]) or i < -sh: + raise IndexError(f"index {i} is out of bounds for dimension {dim} with size {sh}") + indices_filtered[dim] = ((i, i+1), 1) if i >= 0 else ((sh+i, sh+i+1), 1) + for dim in type_dim[slice]: + s, e, st = indices_filtered[dim].indices(self.shape[dim]) + indices_filtered[dim] = ((0, 0) if (st > 0 and e < s) or (st <= 0 and e > s) else (s, e) if st > 0 else (e+1, s+1), st) + for dim in type_dim[Tensor]: indices_filtered[dim] = ((0, self.shape[dim]), 1) + new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered) + ret = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0]) # add strides by pad -> reshape -> shrink if any(abs(s) != 1 for s in strides): strides = tuple(abs(s) for s in strides) - padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape))) - reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))) - new_shape = list(reshaped_tensor.shape[::2]) - sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))) + ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape))) + ret = ret.reshape(flatten([sh // s, s] for s, sh in zip(strides, ret.shape))) + ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2]) - # inject dim=1 for None and collapse dim for int + # inject 1 for dim where it's None and collapse dim for int + new_shape = list(ret.shape) for dim in type_dim[None]: new_shape.insert(dim, 1) for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim) for dim_sh in new_shape: assert isinstance(dim_sh, int), f"does not support symbolic shape {dim_sh}" - ret = sliced_tensor.reshape(tuple(new_shape)) + ret = ret.reshape(tuple(new_shape)) # 3. advanced indexing (copy) if type_dim[Tensor]: