diff --git a/test/test_schedule.py b/test/test_schedule.py index 057debc297..cdc15d324d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1266,5 +1266,72 @@ class TestSchedule(unittest.TestCase): out = x.argmax(1) run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape +class TestIndexing(unittest.TestCase): + def check_schedule(self, xt:Tensor, cnt:int): + s = xt.schedule() + kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL]) + run_schedule(s) + self.assertEqual(kernel_cnt, cnt) + + def test_simple_indexing(self): + X = Tensor.randn(10, 10).realize() + idxs = Tensor([0, 2]).realize() + xt = X[idxs] + self.check_schedule(xt, 3) + np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()]) + + def test_simple_indexing_alt(self): + X = Tensor.arange(16).reshape(4, 4) + xt = X[[1, 2], [1, 2]] + self.check_schedule(xt, 5) + np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]]) + + def test_advanced_indexing(self): + X = Tensor.arange(10)+1 + xt = X[[0]] + self.check_schedule(xt, 3) + np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]]) + + def test_advanced_indexing_alt(self): + X = Tensor.arange(6).reshape(3, 2)+1 + xt = X[[Tensor([2]), Tensor([1])]] + self.check_schedule(xt, 6) + np.testing.assert_equal(xt.numpy(), 6) + + def test_advanced_simple_indexing_combined(self): + X = Tensor.arange(16).reshape(4, 4) + xt = X[1:2, [1, 2]] + self.check_schedule(xt, 4) + np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]]) + + def test_push_through_reshape(self): + Tensor.manual_seed(0) + x = Tensor.randn(10, 20).realize() + out = x.argmax(1) + self.check_schedule(out, 3) + np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1)) + + def test_arange_push_through_expand(self): + Tensor.manual_seed(0) + a = Tensor.arange(4,) + b = Tensor.randn(4, 4).realize() + out = a+b + self.check_schedule(out, 2) + np.testing.assert_allclose(out.numpy(), np.arange(4)+b.numpy()) + + def test_argmin(self): + Tensor.manual_seed(0) + x = Tensor.randn(4, 32).realize() + out = x.argmin(-1) + self.check_schedule(out, 3) + np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1)) + + def test_argmax(self): + Tensor.manual_seed(0) + x = Tensor.randn(4, 32).realize() + out = x.argmax(-1) + self.check_schedule(out, 3) + np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1)) + if __name__ == '__main__': unittest.main(verbosity=2)