fix merge_reduce_ends (#15659)
* fix merge_reduce_ends: the same range with different nesting should not merge, e.g. two chained cumsums should not merge
* skip that test on CI CL
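For context, a minimal repro sketch of the failure mode (my construction from the commit message above, not code from this commit): chaining two cumsums makes two REDUCEs reuse the same RANGE at different nesting depths, and merging their ENDs corrupts the result.

# hypothetical repro, assuming tinygrad's public Tensor API; before this fix,
# the ENDs of the two chained cumsums could be merged and miscompute
import numpy as np
from tinygrad import Tensor

out = Tensor.arange(513).float().cumsum(0).cumsum(0).numpy()
ref = np.arange(513, dtype=np.float32).cumsum().cumsum()
np.testing.assert_allclose(out, ref, rtol=1e-5)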
@@ -24,6 +24,10 @@ class TestArange(unittest.TestCase):
     self.assertEqual(self._get_flops(Tensor.arange(256), np.arange(256)), 0)
     self.assertEqual(self._get_flops(Tensor.arange(2560), np.arange(2560)), 0)
 
+  @unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
+  def test_arange_cumsum(self):
+    np.testing.assert_equal(Tensor.arange(513).cumsum(0).numpy(), np.arange(513).cumsum())
+
   def test_arange_cat(self):
     t = Tensor.arange(2, dtype=dtypes.int)+Tensor([3])
     self.assertEqual(t.cat(t).tolist(), [3, 4, 3, 4])
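The 513-element size presumably pushes cumsum just past the split threshold into its two-stage (block-wise) lowering, which is what exercises the merged-END path. A numpy model of the two-stage idea (an illustration of the general technique, not tinygrad's actual lowering; `block` is an arbitrary choice here):

# stage 1: cumsum within each block; stage 2: add the running offset of all
# preceding blocks to every element of a block
import numpy as np

def two_stage_cumsum(x, block=128):
  n = len(x)
  pad = (-n) % block
  xb = np.pad(x, (0, pad)).reshape(-1, block)
  intra = xb.cumsum(axis=1)
  offsets = np.concatenate([[0], intra[:-1, -1].cumsum()])
  return (intra + offsets[:, None]).reshape(-1)[:n]

x = np.arange(513)
assert (two_stage_cumsum(x) == x.cumsum()).all()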
@@ -510,7 +510,7 @@ class TestSchedule(unittest.TestCase):
     np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
 
   def test_cumsum_parallel_reduce_fused(self):
-    # two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END
+    # two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
     step, num_steps = 513, 10
     t = Tensor.arange(step).float().realize()
     phase = t.cumsum()
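Loop-level intuition for the "should merge" case (a sketch of the lowering as I understand it, not real tinygrad IR): two REDUCEs at the same nesting depth over one RANGE can live in a single loop and share its END.

acc0, acc1 = 0.0, 0.0
xs = [float(i) for i in range(513)]
for i in range(len(xs)):   # one RANGE
  acc0 += xs[i]            # REDUCE #1
  acc1 += xs[i] * xs[i]    # REDUCE #2
# one shared END: both accumulators finalize when the single loop closes
assert acc0 == sum(xs) and acc1 == sum(v * v for v in xs)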
@@ -521,6 +521,12 @@ class TestSchedule(unittest.TestCase):
     expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
     np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
 
+  @unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
+  def test_reduce_different_nesting_depth(self):
+    # two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
+    x = Tensor.arange(768).reshape(3, 256).float()
+    np.testing.assert_allclose((x.sum(axis=1) + x.sum(axis=1).sum()).numpy(), x.numpy().sum(axis=1) + x.numpy().sum(axis=1).sum())
+
   def test_multimatmul_fusion(self):
     Tensor.manual_seed(0)
     a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
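Loop-level intuition for why these ENDs must stay separate (again a sketch, not real IR): the column reduce runs under two different enclosing contexts, once under the output row loop and once recomputed inside the outer reduce, so its two ENDs close in different contexts and merging them would be invalid.

import numpy as np
x = np.arange(768).reshape(3, 256).astype(np.float32)

rows = np.zeros(3, dtype=np.float32)
for r in range(3):         # output row loop
  for c in range(256):     # RANGE c: END closes under the output loop
    rows[r] += x[r, c]

total = 0.0
for r in range(3):         # outer REDUCE over r
  for c in range(256):     # same logical RANGE c: END now closes inside a reduce
    total += x[r, c]

np.testing.assert_allclose(rows + total, x.sum(axis=1) + x.sum())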
@@ -328,11 +328,23 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
   return acc.after(end).index(UOp.const(dtypes.int, 0))
 
 def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
-  # merge ENDs that share the same range (only those created by reduce_to_acc)
+  # merge ENDs that share the same range and nesting context (only those created by reduce_to_acc)
+  # ENDs at different nesting depths get cloned RANGEs so each RANGE maps to one END
   range_to_ends: dict[tuple[UOp, ...], list[UOp]] = {}
   for u in sink.backward_slice:
     if u.op is Ops.END and u.tag == "mergeable": range_to_ends.setdefault(u.src[1:], []).append(u)
-  subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in range_to_ends.items() if len(ends) > 1 for e in ends}
+  subs: dict[UOp, UOp] = {}
+  next_axis = max((u.arg[0] for u in sink.backward_slice if u.op is Ops.RANGE), default=-1) + 1
+  for r, ends in range_to_ends.items():
+    if len(ends) <= 1: continue
+    by_ctx: dict[frozenset[UOp], list[UOp]] = {}
+    for e in ends: by_ctx.setdefault(frozenset(e.ranges), []).append(e)
+    for i, group in enumerate(by_ctx.values()):
+      tr = r if i == 0 else tuple(rr.replace(arg=(next_axis + j, *rr.arg[1:])) for j, rr in enumerate(r))
+      if i > 0: next_axis += len(r)
+      mapped = [e.substitute(dict(zip(r, tr))) if i > 0 else e for e in group]
+      merged = mapped[0] if len(mapped) == 1 else UOp.group(*(e.src[0] for e in mapped)).end(*tr)
+      for e in group: subs[e] = merged
   return sink.substitute(subs) if subs else None
 
 pm_reduce = PatternMatcher([
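A distilled toy model of the new grouping logic on plain tuples (hypothetical names and integer range ids, no real UOps): ENDs are bucketed by their range tuple, then split by nesting context; the first context group keeps the original ranges while later groups get clones with fresh axis ids, mirroring the next_axis bookkeeping above.

from collections import defaultdict

ends = [  # (name, range_ids, enclosing_range_ids)
  ("end_a", (0,), frozenset()),      # top level
  ("end_b", (0,), frozenset()),      # top level -> merges with end_a
  ("end_c", (0,), frozenset({1})),   # nested under range 1 -> cloned range
]

by_range = defaultdict(list)
for e in ends: by_range[e[1]].append(e)

next_axis = 2  # one past the highest existing range id
for r, group in by_range.items():
  by_ctx = defaultdict(list)
  for e in group: by_ctx[e[2]].append(e)
  for i, ctx_group in enumerate(by_ctx.values()):
    tr = r if i == 0 else tuple(range(next_axis, next_axis + len(r)))
    if i > 0: next_axis += len(r)
    print([e[0] for e in ctx_group], "-> shared END over ranges", tr)
# prints: ['end_a', 'end_b'] -> shared END over ranges (0,)
#         ['end_c'] -> shared END over ranges (2,)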