diff --git a/test/test_tensor.py b/test/test_tensor.py
index 65d579bb30..9d4c7801a0 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 import unittest
+import itertools
 from tinygrad.tensor import Tensor, Device
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 
@@ -20,6 +21,20 @@ class TestTinygrad(unittest.TestCase):
     val2 = a.numpy()
     np.testing.assert_allclose(val1, val2)
 
+  def test_slicing(self):
+    x = Tensor.randn(10,10)
+    slices = [0,1,9,-1,-10,None] + [slice(s,e) for s,e in itertools.combinations([0,1,-1,None], r=2)] + [slice(9,11), slice(-11,-9)]
+    fmt = lambda s: f'{s.start}:{s.stop}' if isinstance(s, slice) else str(s)
+    for s in list(itertools.product(slices, slices)) + [(None,0,None,0,None), (slice(0,2),None,None,slice(2,4),None,None)]:
+      np.testing.assert_equal(x.numpy()[s], x[s].numpy(), f'Test failed for slice x[{",".join(fmt(x) for x in s)}]')
+    for s in [-11,10]:
+      with self.assertRaises(IndexError):
+        x[s]
+    with self.assertRaises(AssertionError):
+      x[::2]
+    with self.assertRaises(AssertionError):
+      x[0,0,0]
+
   def test_backward_pass(self):
     def test_tinygrad():
       x = Tensor(x_init, requires_grad=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8f0e73d5f7..74344cba8f 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -153,16 +153,31 @@ class Tensor:
 
   # ***** non first class ops (hlops) *****
 
+  # Tensors mostly follow the normal python indexing / slicing behavior for sequences
+  # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
+  # - A slice i:j returns the elements with indices in [i, j)
+  # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
+  # - Negative values for i and j are taken relative to the end of the sequence
+  # - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
+  # - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
+  # - Empty slices are not allowed
+  # - Strides other than 1 are not allowed
   def __getitem__(self, val):
-    arg, new_shape = [], []
-    for i, rs in enumerate(val if isinstance(val, (list, tuple)) else [val]) if val is not None else []:
-      s = slice(rs, rs+1, None) if isinstance(rs, int) else rs
-      arg.append((s.start if s.start is not None else 0, (s.stop if s.stop>=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
-      assert s.step is None or s.step == 1
-      if not isinstance(rs, int): # don't include in shape if it's an int
-        new_shape.append(arg[-1][1] - arg[-1][0])
-    new_shape += [self.shape[i] for i in range(len(arg), len(self.shape))]
-    return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))]).reshape(new_shape if len(new_shape) else (1,))
+    def slcfix(i, sz, default): return default if i is None else max(0, min(sz, sz+i if i < 0 else i)) # Fix negative idxs, clamp to [0,N]
+    new_slice, new_shape = [], []
+    val = [val] if not isinstance(val, (list, tuple)) else val
+    assert sum(s is not None for s in val) <= len(self.shape)
+    assert all(s.step is None or s.step == 1 for s in val if isinstance(s, slice))
+    for i,(sz,s) in enumerate(zip(self.shape, (v for v in val if v is not None))): # Slicing only depends on ints + slices
+      if isinstance(s, int) and not (-sz <= s < sz):
+        raise IndexError(f"index {s} is out of bounds for dimension {i} with size {sz}")
+      new_slice.append((s%sz, s%sz+1) if isinstance(s, int) else (slcfix(s.start, sz, 0), slcfix(s.stop, sz, sz)))
+    for s,sz in zip(val, (self.shape[i-1] for i in itertools.accumulate(s is not None for s in val))): # Shape depends on slices + positions of Nones
+      if not isinstance(s, int):
+        new_shape.append(1 if s is None else slcfix(s.stop, sz, sz) - slcfix(s.start, sz, 0))
+    new_shape += [self.shape[i] for i in range(len(new_slice), len(self.shape))]
+    new_slice += [(0,self.shape[i]) for i in range(len(new_slice), len(self.shape))]
+    return self.slice(arg = new_slice).reshape(new_shape if len(new_shape) else (1,))
 
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
@@ -204,7 +219,7 @@ class Tensor:
   dot = matmul
 
   # (padding_left, padding_right, padding_top, padding_bottom)
-  def pad2d(self, padding:Tuple[int, ...]): return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
+  def pad2d(self, padding:Tuple[int, ...]): return self.slice(arg = [(0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])]) # type: ignore
   # TODO: this is totally not transpose
   def transpose(self, order=(1,0)): return self.permute(order=order)
   def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
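For reference, a minimal sketch (not part of the diff) of the semantics the comment block above documents, assuming the tinygrad Tensor API exercised in test_slicing; each assertion mirrors what numpy would return for the same index:

    import numpy as np
    from tinygrad.tensor import Tensor

    x = Tensor.randn(10, 10)
    assert x[-2].shape == (10,)              # int index selects and drops that axis; -2 counts from the end
    assert x[1:3, -5:].shape == (2, 5)       # slice i:j keeps indices in [i, j); negatives are relative to N
    assert x[9:11].shape == (1, 10)          # out-of-range slice bounds are clamped
    assert x[None, 0:2].shape == (1, 2, 10)  # None (np.newaxis) inserts a new size-1 axis before that position
    np.testing.assert_equal(x[1:3, -5:].numpy(), x.numpy()[1:3, -5:])  # agrees with numpy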