fix merge_reduce_ends (#15659)

* fix merge_reduce_ends

same range with different nesting should not merge, like cumsum twice should not merge

* skip that
This commit is contained in:
chenyu
2026-04-08 17:20:01 -04:00
committed by GitHub
parent cb681da840
commit 4cf2759fc8
3 changed files with 25 additions and 3 deletions

View File

@@ -24,6 +24,10 @@ class TestArange(unittest.TestCase):
self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
def test_arange_cumsum(self):
np.testing.assert_equal(Tensor.arange(513).cumsum(0).numpy(), np.arange(513).cumsum())
def test_arange_cat(self):
t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])

View File

@@ -510,7 +510,7 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
def test_cumsum_parallel_reduce_fused(self):
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
step, num_steps = 513, 10
t = Tensor.arange(step).float().realize()
phase = t.cumsum()
@@ -521,6 +521,12 @@ class TestSchedule(unittest.TestCase):
expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
def test_reduce_different_nesting_depth(self):
# two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
x = Tensor.arange(768).reshape(3, 256).float()
np.testing.assert_allclose((x.sum(axis=1) + x.sum(axis=1).sum()).numpy(), x.numpy().sum(axis=1) + x.numpy().sum(axis=1).sum())
def test_multimatmul_fusion(self):
Tensor.manual_seed(0)
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()

View File

@@ -328,11 +328,23 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
return acc.after(end).index(UOp.const(dtypes.int, 0))
def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
# merge ENDs that share the same range (only those created by reduce_to_acc)
# merge ENDs that share the same range and nesting context (only those created by reduce_to_acc)
# ENDs at different nesting depths get cloned RANGEs so each RANGE maps to one END
range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
for u in sink.backward_slice:
if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u)
subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends}
subs: dict[UOp, UOp] = {}
next_axis = max((u.arg[0] for u in sink.backward_slice if u.op is Ops.RANGE), default=-1) + 1
for r, ends in range_to_ends.items():
if len(ends) <= 1: continue
by_ctx: dict[frozenset[UOp], list[UOp]] = {}
for e in ends: by_ctx.setdefault(frozenset(e.ranges), []).append(e)
for i, group in enumerate(by_ctx.values()):
tr = r if i == 0 else tuple(rr.replace(arg=(next_axis + j, *rr.arg[1:])) for j, rr in enumerate(r))
if i > 0: next_axis += len(r)
mapped = [e.substitute(dict(zip(r, tr))) if i > 0 else e for e in group]
merged = mapped[0] if len(mapped) == 1 else UOp.group(*(e.src[0] for e in mapped)).end(*tr)
for e in group: subs[e] = merged
return sink.substitute(subs) if subs else None
pm_reduce = PatternMatcher([