From 7f798a96305dee2cfa813314d0d419201200dc36 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Tue, 21 Oct 2025 14:53:49 +0200 Subject: [PATCH] Cleanup const buffers (#12829) * split pm_cleanups * update test_schedule * shrink when we remove bufferize * dont do shrink if shape is empty * update tests * remove *1 from metadata * deal with the noop bufferize * only noop on cvar * cleanup * fix if * rename --- test/test_const_folding.py | 3 --- test/test_fusion_op.py | 2 +- test/test_image_dtype.py | 1 - test/test_schedule.py | 11 +++++++--- test/test_tensor.py | 9 ++++----- tinygrad/schedule/rangeify.py | 38 +++++++++++++++++++++-------------- 6 files changed, 36 insertions(+), 28 deletions(-) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index c7bdda8cf5..f0dd3054cf 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -67,12 +67,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase): def test_tensor_one_mul(self): _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4])) - # TODO: these will be fixed with better folding - @unittest.expectedFailure def test_bool_tensor_mul_bool(self): _check_ast_count(0, Tensor([True, False]) * True) _check_ast_count(0, Tensor([True, False]) * False) - @unittest.expectedFailure def test_bool_mul_bool_tensor(self): _check_ast_count(0, True * Tensor([True, False])) _check_ast_count(0, False * Tensor([True, False])) diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py index 255479cc48..6dd9040dc1 100644 --- a/test/test_fusion_op.py +++ b/test/test_fusion_op.py @@ -51,7 +51,7 @@ class TestFusionOp(unittest.TestCase): a = Tensor(val) for _ in range(24): a = Tensor.stack(a, a)[0] sched = a.schedule() - self.assertEqual(len(sched), 1) + self.assertEqual(len(sched), 0) self.assertLess(time.perf_counter()-st, 2.0) def test_recursive_reshape(self): diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index a45fd7e6a0..da1f3aeeea 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -52,7 +52,6 @@ class TestImageDType(unittest.TestCase): assert isinstance(it.uop.base.realized.dtype, ImageDType) np.testing.assert_equal(tst, it.numpy()) - @unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType def test_image_cast_and_back_collapses(self): data = Tensor.randn(9*27*4).realize() tst = data.numpy() diff --git a/test/test_schedule.py b/test/test_schedule.py index f6f12f182c..c346690c03 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -446,7 +446,7 @@ class TestSchedule(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong") def test_fold_conv_batchnorm_optim(self): # this is too high - for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 13)]: + for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]: with self.subTest(optim=optim.__name__): with Tensor.train(): img = Tensor.ones(1,3,4,4) @@ -1863,7 +1863,7 @@ class TestSchedule(unittest.TestCase): yt = Tensor.randn(BS, 10).realize() with Context(SPLIT_REDUCEOP=0): loss = yt.sparse_categorical_crossentropy(Y_train[samples]) - run_schedule(check_schedule(loss, 5)) + run_schedule(check_schedule(loss, 4)) loss_fused = loss.numpy() loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) @@ -2076,6 +2076,11 @@ class TestCopyFolding(unittest.TestCase): check_schedule(b, 0, filter_sink=False) assert b.item() == 1 + def test_one_hot_with_copy(self): + y = Tensor([1, 2, 3]).to("CPU") + x = y.one_hot(10) + check_schedule(x, 3, filter_sink=False) + def test_const_copy_multi(self): x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"]) check_schedule(x, 0, filter_sink=False) @@ -2085,7 +2090,7 @@ class TestCopyFolding(unittest.TestCase): a = Tensor.arange(3).realize() zeros = Tensor.zeros(3).realize() b = (a*zeros).to("CPU") - run_schedule(check_schedule(b, 2, filter_sink=False)) # TODO: 0? + run_schedule(check_schedule(b, 0, filter_sink=False)) self.assertListEqual(b.tolist(), [0, 0, 0]) self.assertEqual(b.device, "CPU") diff --git a/test/test_tensor.py b/test/test_tensor.py index 207803a6e7..468633e560 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -839,12 +839,11 @@ class TestTensorMetadata(unittest.TestCase): self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid") self.assertTrue(y.grad.uop.metadata[0].backward) si = Tensor.schedule(out, x.grad, y.grad)[-1] - self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}") - self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"}) + self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}") + self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"}) bw = [m for m in si.metadata if m.backward] - self.assertEqual(len(bw), 2) - self.assertEqual(bw[0].name, "__mul__") - self.assertEqual(bw[1].name, "sigmoid") + self.assertEqual(len(bw), 1) + self.assertEqual(bw[0].name, "sigmoid") class TestIdxUpcast(unittest.TestCase): def _find_op(self, ast: UOp, op: Ops): diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 8f1add590d..ac768ca3a1 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -192,33 +192,42 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp): # NOTE: if buf src is a const, we don't replace it return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute) -def pre_bufferize(b:UOp, x:UOp, copy:UOp): - nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) - return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1])) +def remove_noop_bufferize(idx,b2): + if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None + new_tag = (idx.src[0].tag or ()) + (b2.tag or ()) or None + return idx.src[0].rtag(new_tag).shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0].rtag(new_tag) -pm_cleanups = pm_mops+PatternMatcher([ +pm_const_buffer_folding = pm_mops+PatternMatcher([ (UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes), (UPat(GroupOp.All-{Ops.BUFFERIZE, Ops.BUFFER}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), (UPat((Ops.BUFFERIZE), name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) and (resolve(prod(x.dtype.shape)!=prod(x.shape)) or x.shape[-1]%4!=0) else None), # remove noop buffers. if we look at the next index we can remove even more of these - # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more - (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), - lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \ - and idx.src[0].op is not Ops.BUFFER_VIEW else None), - # remove reindexing with cost function - (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize), + (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize), # no buffers for const (UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)), + # indexing a const is a const + (UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c), # copy on CONST is CONST (UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)), - (UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b") - .f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize), + # hack if a noop turned to a const + (UPat.cvar("c").f(Ops.NOOP).f(Ops.BUFFERIZE, allow_any_len=True, name="buf"), lambda c,buf: buf.replace(src=(c,)+buf.src[1:])), # mstack on CONST is CONST (UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True), lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None), ]) +def pre_bufferize(b:UOp, x:UOp, copy:UOp): + nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) + return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1])) +pm_remove_bufferize = PatternMatcher([ + # hack so remove_bufferize doesnt remove the buffer before a copy + (UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b") + .f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize), + # remove reindexing with cost function + (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize), +]) + def late_buffer_view(t:UOp, b:UOp): if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")): rngs = b.src[1:] @@ -488,9 +497,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # convert movement ops to ranges tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) - # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right - tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding - tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers") + tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented+pm_const_buffer_folding, name="symbolic") # this supports const folding + tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph