diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fceb122c4a..9788bfc032 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,8 +1,8 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, operator +import time from functools import partialmethod, reduce -from itertools import accumulate, filterfalse +from itertools import accumulate import numpy as np from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf @@ -277,27 +277,23 @@ class Tensor: def normalize_int(e, i, dim_sz): if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1 raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}") - val = list(val) if isinstance(val, tuple) else [val] - if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape): + orig_slices = list(val) if isinstance(val, tuple) else [val] + if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in orig_slices)) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}") - orig_slices = list(val) - ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis] - if len(ellipses_found) > 0: - if len(ellipses_found) != 1: - raise IndexError("an index can only have a single ellipsis ('...')") - ellipsis_idx = ellipses_found[0] - orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices) - else: - orig_slices += [slice(None)] * (len(self.shape) - num_slices) + ellipses_found = [i for i, v in enumerate(orig_slices) if v is Ellipsis] + if len(ellipses_found) > 1: raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_idx = len(orig_slices) if len(ellipses_found) == 0 else ellipses_found[0] + orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices) + tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)] - orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices] - valid_slices = list(filterfalse(lambda x: x is None, orig_slices)) + orig_slices = [slice(None) if isinstance(v, Tensor) else v for v in orig_slices] + valid_slices = [s for s in orig_slices if s is not None] valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))] start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ()) new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides)) - new_shape = tuple(e - s for s, e in new_slice) # Shrink sliced_tensor = self.shrink(new_slice) + new_shape = sliced_tensor.shape # Flip if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)): sliced_tensor = sliced_tensor.flip(axis=flip_axes) @@ -309,15 +305,14 @@ class Tensor: paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape)) padded_tensor = sliced_tensor.pad(paddings) # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s] - new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore + new_shape = flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)) reshaped_tensor = padded_tensor.reshape(new_shape) # Shrink: do [:, 0] new_shape = new_shape[::2] - final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ()) + final_slice = tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)) sliced_tensor = reshaped_tensor.shrink(final_slice) - final_shape = [] + final_shape, it_shape = [], iter(new_shape) sub = [0] * len(tensor_found) - it_shape = iter(new_shape) for i,s in enumerate(orig_slices): if isinstance(s, (int, slice)): dim_shape = next(it_shape) @@ -332,14 +327,14 @@ class Tensor: for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1]) dim = [i[0] for i in tensor_found] idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm - max_dim = max(idx, key=lambda i: i.ndim).ndim - idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx] - sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))] - new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1)) - arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1)) - ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0]) + max_dim = max(i.ndim for i in idx) + idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx] + sum_dim = [d+max_dim-n for n,d in enumerate(dim)] + new_idx = idx[0].reshape(*[1]*dim[0], 1,*idx[0].shape, *[1]*(ret.ndim-dim[0]-1)) + arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1)) + ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0]) for idx_,d in zip(idx[1:],sum_dim[1:]): - new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim)) + new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim)) arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1)) ret = ((new_idx == arange) * ret).sum(d) if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case