fix in/outs calculation in ProgramSpec (#12937)

With the new linearizer, the toposort is a problem; this now matches the spec.
This commit is contained in:
Sieds Lykles
2025-10-27 12:31:41 +01:00
committed by GitHub
parent e93c9bf6a7
commit 072f7c35c5

View File

@@ -81,8 +81,12 @@ class ProgramSpec:
for u in self.uops:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.STORE and (u.src[0].op is Ops.INDEX or (u.src[0].op is Ops.CAST and u.src[0].src[0].op is Ops.INDEX)):
idx = u.src[0] if u.src[0].op is Ops.INDEX else u.src[0].src[0]
if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: self.outs.append(buf.arg)
if u.op is Ops.LOAD and (u.src[0].op is Ops.INDEX or (u.src[0].op is Ops.CAST and u.src[0].src[0].op is Ops.INDEX)):
idx = u.src[0] if u.src[0].op is Ops.INDEX else u.src[0].src[0]
if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: self.ins.append(buf.arg)
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0] == 'i': self.local_size = None