mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix limit_bufs with kernelize (#12415)
This commit is contained in:
@@ -1970,6 +1970,18 @@ class TestSchedule(unittest.TestCase):
|
||||
for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj]
|
||||
self.assertEqual(root.item(), N * 2)
|
||||
|
||||
def test_limit_bufs_kernelize(self):
# Checks the schedule length of a graph that was kernelized before one more op was added:
# N realized tensors are summed into x, x is kernelized, the KERNEL uops in x are counted,
# and scheduling z = x + Tensor.empty(1) must yield exactly kcount+1 items.
# Presumably guards against limit_bufs splitting z further after kernelize (commit title) - TODO confirm.
# NOTE(review): indentation is lost in this view; confirm which statements sit inside the Context block.
|
||||
N = 31
|
||||
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
|
||||
bufs = [Tensor(i).contiguous().realize() for i in range(N)]
|
||||
x = bufs[0]
|
||||
# chain the remaining N-1 tensors onto x one add at a time
|
||||
for y in bufs[1:]: x = x+y
|
||||
x.kernelize()
|
||||
# number of KERNEL uops present in x's graph after kernelize
|
||||
kcount = len([s for s in x.uop.toposort() if s.op is Ops.KERNEL])
|
||||
z = x+Tensor.empty(1) # z only loads 2 buffers
|
||||
sched = z.schedule()
|
||||
# the extra add is expected to contribute exactly one more schedule item
|
||||
self.assertEqual(len(sched), kcount+1)
|
||||
|
||||
def swizzle_cnt(u:UOp) -> int:
  """Count the VIEW uops in `u.toposort()` whose first source is not a BUFFER, DEFINE_GLOBAL or ASSIGN.

  VIEW uops that have no sources at all are not counted.
  """
  base_ops = {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}
  count = 0
  for node in u.toposort():
    # only VIEWs that actually have a source are candidates
    if node.op is not Ops.VIEW or len(node.src) == 0: continue
    if node.src[0].op not in base_ops: count += 1
  return count
|
||||
|
||||
|
||||
@@ -496,7 +496,7 @@ def limit_bufs(ctx:RangeifyContext, root:UOp):
|
||||
bufs: set[UOp] = set()
|
||||
def gate_input(u:UOp):
# Gate callback passed to root.toposort(gate=...): adds load-like uops to the enclosing `bufs` set
# and returns False for them, True for every other uop.
# Presumably a False return stops the traversal from descending past that uop - TODO confirm against toposort.
|
||||
# TODO: add cache to fix n^2
|
||||
# NOTE(review): this view is a flattened diff; the next line appears to be the removed (pre-commit) version.
|
||||
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
|
||||
# Added (post-commit) version: identical except that Ops.ASSIGN is also treated as a load.
|
||||
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
|
||||
return not is_load
|
||||
root.toposort(gate=gate_input)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user