NumPy-like semantics for Tensor.__getitem__ (#506)

* Rewrote Tensor.__getitem__ to fix negative indices and add support for np.newaxis/None

* Fixed pad2d

* mypy doesn't know about mlops methods

* normal python behavior for out-of-bounds slicing

* type: ignore

* inlined idxfix

* added comment for __getitem__

* Better comments, better tests, and fixed bug in np.newaxis
This commit is contained in:
Mitchell Goff
2023-02-08 06:59:46 -08:00
committed by GitHub
parent 0ac3286af0
commit ae4f0aeb5f
2 changed files with 40 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
import numpy as np
import torch
import unittest
import itertools
from tinygrad.tensor import Tensor, Device
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
@@ -20,6 +21,20 @@ class TestTinygrad(unittest.TestCase):
val2 = a.numpy()
np.testing.assert_allclose(val1, val2)
def test_slicing(self):
x = Tensor.randn(10,10)
slices = [0,1,9,-1,-10,None] + [slice(s,e) for s,e in itertools.combinations([0,1,-1,None], r=2)] + [slice(9,11), slice(-11,-9)]
fmt = lambda s: f'{s.start}:{s.stop}' if isinstance(s, slice) else str(s)
for s in list(itertools.product(slices, slices)) + [(None,0,None,0,None), (slice(0,2),None,None,slice(2,4),None,None)]:
np.testing.assert_equal(x.numpy()[s], x[s].numpy(), f'Test failed for slice x[{",".join(fmt(x) for x in s)}]')
for s in [-11,10]:
with self.assertRaises(IndexError):
x[s]
with self.assertRaises(AssertionError):
x[::2]
with self.assertRaises(AssertionError):
x[0,0,0]
def test_backward_pass(self):
def test_tinygrad():
x = Tensor(x_init, requires_grad=True)