grouper tests from fuse_arange_default [pr] (#10394)

This commit is contained in:
qazal
2025-05-18 18:42:43 +03:00
committed by GitHub
parent 17f0f5e764
commit 04b23087d8

View File

@@ -86,6 +86,21 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(b, 1))
np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2)
def test_indexing_scalars_simple(self):
X = Tensor.randn(2, 2).realize()
xt = X[Tensor(1)][Tensor(0)]
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(xt, 2))
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
@unittest.expectedFailure # TODO: failing because of can_chase
def test_indexing_scalars_multiple_dims(self):
X = Tensor.randn(2, 3).realize()
xt = X[Tensor(0)][Tensor(1)]
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(xt, 2))
np.testing.assert_equal(xt.numpy(), X.numpy()[0][1])
def test_push_pads_elementwise(self):
x = Tensor.full((4,4), 2.).contiguous().realize()
y = Tensor.full((4,4), 4.).contiguous().realize()