mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
tests from the indexing fusion branch (#5886)
This commit is contained in:
@@ -1266,5 +1266,72 @@ class TestSchedule(unittest.TestCase):
|
||||
out = x.argmax(1)
|
||||
run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Tensor, cnt:int):
|
||||
s = xt.schedule()
|
||||
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
|
||||
run_schedule(s)
|
||||
self.assertEqual(kernel_cnt, cnt)
|
||||
|
||||
def test_simple_indexing(self):
|
||||
X = Tensor.randn(10, 10).realize()
|
||||
idxs = Tensor([0, 2]).realize()
|
||||
xt = X[idxs]
|
||||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
|
||||
|
||||
def test_simple_indexing_alt(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[[1, 2], [1, 2]]
|
||||
self.check_schedule(xt, 5)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
||||
|
||||
def test_advanced_indexing(self):
|
||||
X = Tensor.arange(10)+1
|
||||
xt = X[[0]]
|
||||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
|
||||
|
||||
def test_advanced_indexing_alt(self):
|
||||
X = Tensor.arange(6).reshape(3, 2)+1
|
||||
xt = X[[Tensor([2]), Tensor([1])]]
|
||||
self.check_schedule(xt, 6)
|
||||
np.testing.assert_equal(xt.numpy(), 6)
|
||||
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[1:2, [1, 2]]
|
||||
self.check_schedule(xt, 4)
|
||||
np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])
|
||||
|
||||
def test_push_through_reshape(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(10, 20).realize()
|
||||
out = x.argmax(1)
|
||||
self.check_schedule(out, 3)
|
||||
np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1))
|
||||
|
||||
def test_arange_push_through_expand(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.arange(4,)
|
||||
b = Tensor.randn(4, 4).realize()
|
||||
out = a+b
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), np.arange(4)+b.numpy())
|
||||
|
||||
def test_argmin(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.argmin(-1)
|
||||
self.check_schedule(out, 3)
|
||||
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
|
||||
|
||||
def test_argmax(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.argmax(-1)
|
||||
self.check_schedule(out, 3)
|
||||
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user