Cleanup const buffers (#12829)

* split pm_cleanups

* update test_schedule

* shrink when we remove bufferize

* dont do shrink if shape is empty

* update tests

* remove *1 from metadata

* deal with the noop bufferize

* only noop on cvar

* cleanup

* fix if

* rename
This commit is contained in:
Sieds Lykles
2025-10-21 14:53:49 +02:00
committed by GitHub
parent 1ad6598963
commit 7f798a9630
6 changed files with 36 additions and 28 deletions

View File

@@ -67,12 +67,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
def test_tensor_one_mul(self): def test_tensor_one_mul(self):
_check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4])) _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
# TODO: these will be fixed with better folding
@unittest.expectedFailure
def test_bool_tensor_mul_bool(self): def test_bool_tensor_mul_bool(self):
_check_ast_count(0, Tensor([True, False]) * True) _check_ast_count(0, Tensor([True, False]) * True)
_check_ast_count(0, Tensor([True, False]) * False) _check_ast_count(0, Tensor([True, False]) * False)
@unittest.expectedFailure
def test_bool_mul_bool_tensor(self): def test_bool_mul_bool_tensor(self):
_check_ast_count(0, True * Tensor([True, False])) _check_ast_count(0, True * Tensor([True, False]))
_check_ast_count(0, False * Tensor([True, False])) _check_ast_count(0, False * Tensor([True, False]))

View File

@@ -51,7 +51,7 @@ class TestFusionOp(unittest.TestCase):
a = Tensor(val) a = Tensor(val)
for _ in range(24): a = Tensor.stack(a, a)[0] for _ in range(24): a = Tensor.stack(a, a)[0]
sched = a.schedule() sched = a.schedule()
self.assertEqual(len(sched), 1) self.assertEqual(len(sched), 0)
self.assertLess(time.perf_counter()-st, 2.0) self.assertLess(time.perf_counter()-st, 2.0)
def test_recursive_reshape(self): def test_recursive_reshape(self):

View File

@@ -52,7 +52,6 @@ class TestImageDType(unittest.TestCase):
assert isinstance(it.uop.base.realized.dtype, ImageDType) assert isinstance(it.uop.base.realized.dtype, ImageDType)
np.testing.assert_equal(tst, it.numpy()) np.testing.assert_equal(tst, it.numpy())
@unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
def test_image_cast_and_back_collapses(self): def test_image_cast_and_back_collapses(self):
data = Tensor.randn(9*27*4).realize() data = Tensor.randn(9*27*4).realize()
tst = data.numpy() tst = data.numpy()

View File

@@ -446,7 +446,7 @@ class TestSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong") @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_fold_conv_batchnorm_optim(self): def test_fold_conv_batchnorm_optim(self):
# this is too high # this is too high
for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 13)]: for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]:
with self.subTest(optim=optim.__name__): with self.subTest(optim=optim.__name__):
with Tensor.train(): with Tensor.train():
img = Tensor.ones(1,3,4,4) img = Tensor.ones(1,3,4,4)
@@ -1863,7 +1863,7 @@ class TestSchedule(unittest.TestCase):
yt = Tensor.randn(BS, 10).realize() yt = Tensor.randn(BS, 10).realize()
with Context(SPLIT_REDUCEOP=0): with Context(SPLIT_REDUCEOP=0):
loss = yt.sparse_categorical_crossentropy(Y_train[samples]) loss = yt.sparse_categorical_crossentropy(Y_train[samples])
run_schedule(check_schedule(loss, 5)) run_schedule(check_schedule(loss, 4))
loss_fused = loss.numpy() loss_fused = loss.numpy()
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
@@ -2076,6 +2076,11 @@ class TestCopyFolding(unittest.TestCase):
check_schedule(b, 0, filter_sink=False) check_schedule(b, 0, filter_sink=False)
assert b.item() == 1 assert b.item() == 1
def test_one_hot_with_copy(self):
y = Tensor([1, 2, 3]).to("CPU")
x = y.one_hot(10)
check_schedule(x, 3, filter_sink=False)
def test_const_copy_multi(self): def test_const_copy_multi(self):
x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"]) x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"])
check_schedule(x, 0, filter_sink=False) check_schedule(x, 0, filter_sink=False)
@@ -2085,7 +2090,7 @@ class TestCopyFolding(unittest.TestCase):
a = Tensor.arange(3).realize() a = Tensor.arange(3).realize()
zeros = Tensor.zeros(3).realize() zeros = Tensor.zeros(3).realize()
b = (a*zeros).to("CPU") b = (a*zeros).to("CPU")
run_schedule(check_schedule(b, 2, filter_sink=False)) # TODO: 0? run_schedule(check_schedule(b, 0, filter_sink=False))
self.assertListEqual(b.tolist(), [0, 0, 0]) self.assertListEqual(b.tolist(), [0, 0, 0])
self.assertEqual(b.device, "CPU") self.assertEqual(b.device, "CPU")

View File

@@ -839,12 +839,11 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid") self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
self.assertTrue(y.grad.uop.metadata[0].backward) self.assertTrue(y.grad.uop.metadata[0].backward)
si = Tensor.schedule(out, x.grad, y.grad)[-1] si = Tensor.schedule(out, x.grad, y.grad)[-1]
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}") self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"}) self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
bw = [m for m in si.metadata if m.backward] bw = [m for m in si.metadata if m.backward]
self.assertEqual(len(bw), 2) self.assertEqual(len(bw), 1)
self.assertEqual(bw[0].name, "__mul__") self.assertEqual(bw[0].name, "sigmoid")
self.assertEqual(bw[1].name, "sigmoid")
class TestIdxUpcast(unittest.TestCase): class TestIdxUpcast(unittest.TestCase):
def _find_op(self, ast: UOp, op: Ops): def _find_op(self, ast: UOp, op: Ops):

View File

@@ -192,33 +192,42 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# NOTE: if buf src is a const, we don't replace it # NOTE: if buf src is a const, we don't replace it
return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute) return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute)
def pre_bufferize(b:UOp, x:UOp, copy:UOp): def remove_noop_bufferize(idx,b2):
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:]) if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None
return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1])) new_tag = (idx.src[0].tag or ()) + (b2.tag or ()) or None
return idx.src[0].rtag(new_tag).shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0].rtag(new_tag)
pm_cleanups = pm_mops+PatternMatcher([ pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes), (UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
(UPat(GroupOp.All-{Ops.BUFFERIZE, Ops.BUFFER}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), (UPat(GroupOp.All-{Ops.BUFFERIZE, Ops.BUFFER}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
(UPat((Ops.BUFFERIZE), name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) (UPat((Ops.BUFFERIZE), name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType)
and (resolve(prod(x.dtype.shape)!=prod(x.shape)) or x.shape[-1]%4!=0) else None), and (resolve(prod(x.dtype.shape)!=prod(x.shape)) or x.shape[-1]%4!=0) else None),
# remove noop buffers. if we look at the next index we can remove even more of these # remove noop buffers. if we look at the next index we can remove even more of these
# NOTE: this is mostly the same case as below, but if there's no INDEX this gets more (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \
and idx.src[0].op is not Ops.BUFFER_VIEW else None),
# remove reindexing with cost function
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# no buffers for const # no buffers for const
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)), (UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
# indexing a const is a const
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
# copy on CONST is CONST # copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)), (UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
(UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b") # hack if a noop turned to a const
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize), (UPat.cvar("c").f(Ops.NOOP).f(Ops.BUFFERIZE, allow_any_len=True, name="buf"), lambda c,buf: buf.replace(src=(c,)+buf.src[1:])),
# mstack on CONST is CONST # mstack on CONST is CONST
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True), (UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None), lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),
]) ])
def pre_bufferize(b:UOp, x:UOp, copy:UOp):
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1]))
pm_remove_bufferize = PatternMatcher([
# hack so remove_bufferize doesnt remove the buffer before a copy
(UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b")
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
# remove reindexing with cost function
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
])
def late_buffer_view(t:UOp, b:UOp): def late_buffer_view(t:UOp, b:UOp):
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")): if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
rngs = b.src[1:] rngs = b.src[1:]
@@ -488,9 +497,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# convert movement ops to ranges # convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented+pm_const_buffer_folding, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph