diff --git a/test/test_schedule.py b/test/test_schedule.py index e04384c68c..45d1e7f61b 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -86,6 +86,21 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(b, 1)) np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2) + def test_indexing_scalars_simple(self): + X = Tensor.randn(2, 2).realize() + xt = X[Tensor(1)][Tensor(0)] + with Context(FUSE_ARANGE=1): + run_schedule(check_schedule(xt, 2)) + np.testing.assert_equal(xt.numpy(), X.numpy()[1][0]) + + @unittest.expectedFailure # TODO: failing because of can_chase + def test_indexing_scalars_multiple_dims(self): + X = Tensor.randn(2, 3).realize() + xt = X[Tensor(0)][Tensor(1)] + with Context(FUSE_ARANGE=1): + run_schedule(check_schedule(xt, 2)) + np.testing.assert_equal(xt.numpy(), X.numpy()[0][1]) + def test_push_pads_elementwise(self): x = Tensor.full((4,4), 2.).contiguous().realize() y = Tensor.full((4,4), 4.).contiguous().realize()