__getitem__ refactoring (#1586)

* try

* try

* form

* form

* form

* form

* lint

* small change

* preserve old

* revert to explicit reshape
Umut Zengin
2023-08-21 04:42:30 +03:00
committed by GitHub
parent d627349af0
commit 3fc7e984f0

@@ -1,8 +1,8 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, operator
+import time
 from functools import partialmethod, reduce
-from itertools import accumulate, filterfalse
+from itertools import accumulate
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
 from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
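
Note: the `operator` and `filterfalse` imports go away because their only remaining uses below are replaced by tinygrad's `flatten` helper and a plain list comprehension. As a rough sketch (the exact upstream definition in tinygrad/helpers.py may differ), `flatten` just concatenates one level of nesting:

def flatten(l): return [item for sublist in l for item in sublist]

assert flatten([[2, 3], [4, 5]]) == [2, 3, 4, 5]  # same list reduce(operator.add, ...) built before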
@@ -277,27 +277,23 @@ class Tensor:
     def normalize_int(e, i, dim_sz):
       if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
       raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
-    val = list(val) if isinstance(val, tuple) else [val]
-    if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape):
+    orig_slices = list(val) if isinstance(val, tuple) else [val]
+    if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in orig_slices)) > len(self.shape):
       raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
-    orig_slices = list(val)
-    ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
-    if len(ellipses_found) > 0:
-      if len(ellipses_found) != 1:
-        raise IndexError("an index can only have a single ellipsis ('...')")
-      ellipsis_idx = ellipses_found[0]
-      orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
-    else:
-      orig_slices += [slice(None)] * (len(self.shape) - num_slices)
+    ellipses_found = [i for i, v in enumerate(orig_slices) if v is Ellipsis]
+    if len(ellipses_found) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
+    ellipsis_idx = len(orig_slices) if len(ellipses_found) == 0 else ellipses_found[0]
+    orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
     tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)]
-    orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices]
-    valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
+    orig_slices = [slice(None) if isinstance(v, Tensor) else v for v in orig_slices]
+    valid_slices = [s for s in orig_slices if s is not None]
     valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
     start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
     new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
-    new_shape = tuple(e - s for s, e in new_slice)
     # Shrink
     sliced_tensor = self.shrink(new_slice)
+    new_shape = sliced_tensor.shape
     # Flip
     if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
       sliced_tensor = sliced_tensor.flip(axis=flip_axes)
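
Note on the ellipsis rewrite: the two branches collapse into one splice. When no `...` is present, `ellipsis_idx` points one past the end of the list, so the same slice assignment degenerates into appending the padding `slice(None)`s. A standalone sketch of that splice on plain lists (illustrative only, not the code path above):

def expand(indices, ndim):
  num_slices = sum(isinstance(v, (slice, int)) for v in indices)
  ell = [i for i, v in enumerate(indices) if v is Ellipsis]
  ellipsis_idx = len(indices) if len(ell) == 0 else ell[0]
  indices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (ndim - num_slices)
  return indices

print(expand([0, Ellipsis], 4))  # [0, slice(None), slice(None), slice(None)]
print(expand([0], 3))            # splice past the end appends: [0, slice(None), slice(None)]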
@@ -309,15 +305,14 @@ class Tensor:
       paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
       padded_tensor = sliced_tensor.pad(paddings)
       # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
-      new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
+      new_shape = flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))
       reshaped_tensor = padded_tensor.reshape(new_shape)
       # Shrink: do [:, 0]
       new_shape = new_shape[::2]
-      final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
+      final_slice = tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))
       sliced_tensor = reshaped_tensor.shrink(final_slice)
-    final_shape = []
+    final_shape, it_shape = [], iter(new_shape)
     sub = [0] * len(tensor_found)
-    it_shape = iter(new_shape)
     for i,s in enumerate(orig_slices):
       if isinstance(s, (int, slice)):
         dim_shape = next(it_shape)
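
Note: the pad/reshape/shrink sequence above is how a strided slice is built from movement ops alone. Each padded dimension of size d is reshaped to [d // s, s] and then shrunk to its first column, keeping every s-th element. The same idea in NumPy, as a sanity check (not tinygrad code):

import numpy as np
x = np.arange(7)                      # want x[::3] == [0, 3, 6]
s = 3
padded = np.pad(x, (0, -len(x) % s))  # pad 7 -> 9, a multiple of s
print(padded.reshape(-1, s)[:, 0])    # reshape [9] -> [3, 3], keep column 0 -> [0 3 6]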
@@ -332,14 +327,14 @@ class Tensor:
       for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1])
       dim = [i[0] for i in tensor_found]
       idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
-      max_dim = max(idx, key=lambda i: i.ndim).ndim
-      idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
-      sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))]
-      new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1))
-      arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1))
-      ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0])
+      max_dim = max(i.ndim for i in idx)
+      idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
+      sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
+      new_idx = idx[0].reshape(*[1]*dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
+      arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
+      ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
       for idx_,d in zip(idx[1:],sum_dim[1:]):
-        new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim))
+        new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
         arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
         ret = ((new_idx == arange) * ret).sum(d)
       if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
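
Note: this block is the gather lowering for Tensor indices. Each index tensor is first wrapped to non-negative values (the sign()/relu() expression adds the dimension size to negative entries), then compared against an arange along the indexed dimension to form a one-hot mask, and the product is summed out. A NumPy sketch of the idea (not tinygrad code):

import numpy as np
ret = np.array([10., 20., 30., 40.])
idx = np.array([2, 0])                            # already wrapped to non-negative
onehot = (np.arange(4)[None, :] == idx[:, None])  # shape (2, 4), one-hot rows
print((ret * onehot).sum(axis=1))                 # [30. 10.] == ret[idx]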