fix limit_bufs with kernelize (#12415)

This commit is contained in:
qazal
2025-10-02 07:49:11 +03:00
committed by GitHub
parent d1c868f990
commit 6fc6b51b59
2 changed files with 13 additions and 1 deletions

View File

@@ -1970,6 +1970,18 @@ class TestSchedule(unittest.TestCase):
for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj]
self.assertEqual(root.item(), N * 2)
def test_limit_bufs_kernelize(self):
  """Scheduling on top of an already-kernelized sum must add exactly one more schedule item."""
  num_bufs = 31
  # NOTE(review): the flattened source does not show how far the Context block extends;
  # it is assumed to cover only the creation of the realized inputs -- confirm against the repo.
  with Context(TRACK_MATCH_STATS=0, DEBUG=0):
    inputs = [Tensor(idx).contiguous().realize() for idx in range(num_bufs)]
  # chain every input into one running sum: ((b0 + b1) + b2) + ...
  total = inputs[0]
  for nxt in inputs[1:]:
    total = total + nxt
  total.kernelize()
  # number of KERNEL nodes present in the graph after kernelize
  kernel_nodes = sum(1 for node in total.uop.toposort() if node.op is Ops.KERNEL)
  # z only loads 2 buffers
  z = total + Tensor.empty(1)
  schedule_items = z.schedule()
  self.assertEqual(len(schedule_items), kernel_nodes + 1)
def swizzle_cnt(u:UOp) -> int:
  """Count the VIEW nodes reachable from `u` that have at least one source and whose first source's op
  is not BUFFER, DEFINE_GLOBAL or ASSIGN."""
  excluded_src_ops = {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}
  return sum(
    1 for node in u.toposort()
    if node.op is Ops.VIEW and len(node.src) != 0 and node.src[0].op not in excluded_src_ops
  )

View File

@@ -496,7 +496,7 @@ def limit_bufs(ctx:RangeifyContext, root:UOp):
bufs: set[UOp] = set()
def gate_input(u:UOp):
# TODO: add cache to fix n^2
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)