__getitem__ refactoring (#1586)
* dene
* dene
* form
* form
* form
* form
* lint
* small change
* preserve old
* revert to explicit reshape
@@ -1,8 +1,8 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, operator
+import time
 from functools import partialmethod, reduce
-from itertools import accumulate, filterfalse
+from itertools import accumulate
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
 from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
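Note: the dropped operator and filterfalse imports track the body changes below, where reduce(operator.add, ...) and filterfalse are replaced by a flatten helper and a plain list comprehension. A minimal sketch of such a flatten (the real helper lives elsewhere in tinygrad, likely tinygrad/helpers.py; this is an illustration, not the exact source):

# flatten one level of nesting; same result as reduce(operator.add, lists, [])
def flatten(l):
  return [item for sublist in l for item in sublist]

assert flatten([[2, 3], [4, 5]]) == [2, 3, 4, 5]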
@@ -277,27 +277,23 @@ class Tensor:
     def normalize_int(e, i, dim_sz):
       if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
       raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
-    val = list(val) if isinstance(val, tuple) else [val]
-    if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape):
+    orig_slices = list(val) if isinstance(val, tuple) else [val]
+    if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in orig_slices)) > len(self.shape):
       raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
-    orig_slices = list(val)
-    ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
-    if len(ellipses_found) > 0:
-      if len(ellipses_found) != 1: raise IndexError("an index can only have a single ellipsis ('...')")
-      ellipsis_idx = ellipses_found[0]
-      orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
-    else:
-      orig_slices += [slice(None)] * (len(self.shape) - num_slices)
+    ellipses_found = [i for i, v in enumerate(orig_slices) if v is Ellipsis]
+    if len(ellipses_found) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
+    ellipsis_idx = len(orig_slices) if len(ellipses_found) == 0 else ellipses_found[0]
+    orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
 
     tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)]
-    orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices]
-    valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
+    orig_slices = [slice(None) if isinstance(v, Tensor) else v for v in orig_slices]
+    valid_slices = [s for s in orig_slices if s is not None]
     valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
     start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
     new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
-    new_shape = tuple(e - s for s, e in new_slice)
     # Shrink
     sliced_tensor = self.shrink(new_slice)
+    new_shape = sliced_tensor.shape
     # Flip
     if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
       sliced_tensor = sliced_tensor.flip(axis=flip_axes)
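Note: this hunk reduces any basic slice to a shrink (a contiguous sub-range per dimension) plus a flip for negative strides; the stride magnitude is applied in the next hunk. A numpy sketch of that decomposition for one dimension (illustrative names, not tinygrad API):

import numpy as np

t = np.arange(10)
s, e, st = slice(8, 2, -2).indices(len(t))      # start=8, stop=2, step=-2
shrunk = t[e+1:s+1]                             # shrink: like self.shrink(((e+1, s+1),))
flipped = shrunk[::-1]                          # flip, since st < 0
assert (flipped[::abs(st)] == t[8:2:-2]).all()  # |st| is handled by the next hunk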
@@ -309,15 +305,14 @@ class Tensor:
       paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
       padded_tensor = sliced_tensor.pad(paddings)
       # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
-      new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
+      new_shape = flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))
       reshaped_tensor = padded_tensor.reshape(new_shape)
       # Shrink: do [:, 0]
       new_shape = new_shape[::2]
-      final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
+      final_slice = tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))
       sliced_tensor = reshaped_tensor.shrink(final_slice)
-    final_shape = []
+    final_shape, it_shape = [], iter(new_shape)
     sub = [0] * len(tensor_found)
-    it_shape = iter(new_shape)
     for i,s in enumerate(orig_slices):
       if isinstance(s, (int, slice)):
         dim_shape = next(it_shape)
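Note: strides of magnitude > 1 are implemented with movement ops only: pad each sliced dimension up to a multiple of the stride, reshape [n] -> [n//s, s], then shrink to the first column. A numpy sketch of the trick in one dimension (assumed shapes, illustrative only):

import numpy as np

x, s = np.arange(7), 3                                      # want x[::3] == [0, 3, 6]
num_zeros = (s - len(x) % s) % s                            # pad length up to a multiple of s
padded = np.concatenate([x, np.zeros(num_zeros, x.dtype)])  # like Tensor.pad
grouped = padded.reshape(len(padded) // s, s)               # [9] -> [3, 3]
strided = grouped[:, :1].reshape(-1)                        # like shrink: do [:, 0]
assert (strided == x[::s]).all()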
@@ -332,14 +327,14 @@ class Tensor:
       for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1])
       dim = [i[0] for i in tensor_found]
       idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
-      max_dim = max(idx, key=lambda i: i.ndim).ndim
-      idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
-      sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))]
-      new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1))
-      arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1))
-      ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0])
+      max_dim = max(i.ndim for i in idx)
+      idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
+      sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
+      new_idx = idx[0].reshape(*[1]*dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
+      arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
+      ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
       for idx_,d in zip(idx[1:],sum_dim[1:]):
-        new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim))
+        new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
         arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
         ret = ((new_idx == arange) * ret).sum(d)
       if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
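Note: Tensor indices are lowered to arithmetic rather than a gather: an arange along the indexed dimension is compared against the (broadcast) index to build a one-hot mask, which is multiplied in and summed out. A numpy sketch of the idea for a single 1-D index (illustrative only):

import numpy as np

ret = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])                      # want ret[idx]
arange = np.arange(3).reshape(3, 1, 1)      # runs along the indexed dim
new_idx = idx.reshape(1, 2, 1)              # broadcast against arange
onehot = (arange == new_idx)                # (3, 2, 1) one-hot mask
gathered = (ret.reshape(3, 1, 4) * onehot).sum(axis=0)  # (2, 4)
assert (gathered == ret[idx]).all()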