tests from the indexing fusion branch (#5886)

This commit is contained in:
qazal
2024-08-03 16:56:48 +08:00
committed by GitHub
parent a77eab89ca
commit af59b2eea9

View File

@@ -1266,5 +1266,72 @@ class TestSchedule(unittest.TestCase):
out = x.argmax(1)
run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Tensor, cnt:int):
s = xt.schedule()
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
run_schedule(s)
self.assertEqual(kernel_cnt, cnt)
def test_simple_indexing(self):
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = X[idxs]
self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
def test_simple_indexing_alt(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[[1, 2], [1, 2]]
self.check_schedule(xt, 5)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
def test_advanced_indexing(self):
X = Tensor.arange(10)+1
xt = X[[0]]
self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
def test_advanced_indexing_alt(self):
X = Tensor.arange(6).reshape(3, 2)+1
xt = X[[Tensor([2]), Tensor([1])]]
self.check_schedule(xt, 6)
np.testing.assert_equal(xt.numpy(), 6)
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
self.check_schedule(xt, 4)
np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])
def test_push_through_reshape(self):
Tensor.manual_seed(0)
x = Tensor.randn(10, 20).realize()
out = x.argmax(1)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1))
def test_arange_push_through_expand(self):
Tensor.manual_seed(0)
a = Tensor.arange(4,)
b = Tensor.randn(4, 4).realize()
out = a+b
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), np.arange(4)+b.numpy())
def test_argmin(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmin(-1)
self.check_schedule(out, 3)
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
def test_argmax(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmax(-1)
self.check_schedule(out, 3)
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
if __name__ == '__main__':
unittest.main(verbosity=2)