push permutes through fused reduces (#10628)

* fix pushing reshapes through reduceops

* reduceop_view_right should assert on ndims mismatch

* update that, view.reshape asserts it
This commit is contained in:
qazal
2025-06-05 16:14:04 +03:00
committed by GitHub
parent 8db0ba1161
commit 8c5ea00522
2 changed files with 13 additions and 8 deletions

View File

@@ -89,14 +89,20 @@ class TestSchedule(unittest.TestCase):
with self.assertRaises(RecursionError):
with Context(FUSE_ARANGE=1, NOOPT=0): self.test_arange_avgpool2d(kcount=1)
# grouper error
@unittest.expectedFailure
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
# all permutes, reshapes, expands and shrinks push through the reduce
def test_arange_sum(self):
a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(a, 1))
self.assertListEqual(a.tolist(), [1, 5, 9])
def test_permute_arange(self):
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
with Context(FUSE_ARANGE=1):
run_schedule(check_schedule(a, 1))
self.assertListEqual(a.tolist(), [[15]])
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
def test_error_on_device_mismatch(self):
a = Tensor.empty(10)