mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
NumPy-like semantics for Tensor.__getitem__ (#506)
* Rewrote Tensor.__getitem__ to fix negative indices and add support for np.newaxis/None * Fixed pad2d * mypy doesn't know about mlops methods * normal python behavior for out-of-bounds slicing * type: ignore * inlined idxfix * added comment for __getitem__ * Better comments, better tests, and fixed bug in np.newaxis
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
import itertools
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
|
||||
@@ -20,6 +21,20 @@ class TestTinygrad(unittest.TestCase):
|
||||
val2 = a.numpy()
|
||||
np.testing.assert_allclose(val1, val2)
|
||||
|
||||
def test_slicing(self):
|
||||
x = Tensor.randn(10,10)
|
||||
slices = [0,1,9,-1,-10,None] + [slice(s,e) for s,e in itertools.combinations([0,1,-1,None], r=2)] + [slice(9,11), slice(-11,-9)]
|
||||
fmt = lambda s: f'{s.start}:{s.stop}' if isinstance(s, slice) else str(s)
|
||||
for s in list(itertools.product(slices, slices)) + [(None,0,None,0,None), (slice(0,2),None,None,slice(2,4),None,None)]:
|
||||
np.testing.assert_equal(x.numpy()[s], x[s].numpy(), f'Test failed for slice x[{",".join(fmt(x) for x in s)}]')
|
||||
for s in [-11,10]:
|
||||
with self.assertRaises(IndexError):
|
||||
x[s]
|
||||
with self.assertRaises(AssertionError):
|
||||
x[::2]
|
||||
with self.assertRaises(AssertionError):
|
||||
x[0,0,0]
|
||||
|
||||
def test_backward_pass(self):
|
||||
def test_tinygrad():
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
|
||||
Reference in New Issue
Block a user