diff --git a/test/test_schedule.py b/test/test_schedule.py index 4cf16351c0..d38f48f1c5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1970,6 +1970,18 @@ class TestSchedule(unittest.TestCase): for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj] self.assertEqual(root.item(), N * 2) + def test_limit_bufs_kernelize(self): + N = 31 + with Context(TRACK_MATCH_STATS=0, DEBUG=0): + bufs = [Tensor(i).contiguous().realize() for i in range(N)] + x = bufs[0] + for y in bufs[1:]: x = x+y + x.kernelize() + kcount = len([s for s in x.uop.toposort() if s.op is Ops.KERNEL]) + z = x+Tensor.empty(1) # z only loads 2 buffers + sched = z.schedule() + self.assertEqual(len(sched), kcount+1) + def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}]) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 2d401f4125..81baabc17a 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -496,7 +496,7 @@ def limit_bufs(ctx:RangeifyContext, root:UOp): bufs: set[UOp] = set() def gate_input(u:UOp): # TODO: add cache to fix n^2 - if is_load:=(u.op in {Ops.BUFFERIZE, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u) + if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u) return not is_load root.toposort(gate=gate_input)