gpu work now?

This commit is contained in:
ttomsa
2025-10-09 19:22:50 +01:00
parent 2aeca4fae1
commit 86782182a3

View File

@@ -54,21 +54,21 @@ def add_endrange(x:UOp):
return x.replace(src=tuple(src))
def add_endif(x:UOp):
# Groups the sources of `x` by the Ops.IF associated with each of them and wraps every IF-keyed group in an Ops.ENDIF.
# Returns None when no group is keyed by an Ops.IF, otherwise `x` with its sources rebuilt.
# HACK: sometimes the if is in the load, either need a better way of inserting the endif or it should be added when the if is created
# _find_if computes the grouping key of a single source: an Ops.IF when one is found, otherwise the source itself.
def _find_if(x:UOp):
# anything that is not a STORE is its own key
if x.op is not Ops.STORE: return x
# a STORE that has an IF as its third source is keyed by that IF
if len(x.src) > 2 and x.src[2].op is Ops.IF: return x.src[2]
# otherwise start from the second source and keep following the last source until a LOAD (or a node without sources) is reached
ret = x.src[1]
while ret.src and ret.op is not Ops.LOAD: ret = ret.src[-1]
# if the node we stopped at has no IF as its second source, fall back to the STORE itself
if len(ret.src) < 2 or ret.src[1].op is not Ops.IF: return x
return ret.src[1]
# NOTE(review): the next two assignments both bind `groups`, so read top to bottom the second one wins and _find_if is unused;
# they look like the two sides of a diff -- confirm which one is meant to be live.
# NOTE(review): presumably `groupby` is itertools.groupby, which only merges *adjacent* items with equal keys; the dict
# comprehension keeps just the last group per key, so sources sharing an IF are assumed to be contiguous in x.src -- TODO confirm.
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: _find_if(k))}
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: k.src[2] if len(k.src) >= 3 and k.src[2].op is Ops.IF else k)}
# nothing to close: tell the caller that no rewrite happened
if not any(k.op is Ops.IF for k in groups): return None
# IF-keyed groups become ENDIF(if, *members); every other key is passed through unchanged
return x.replace(src=tuple(UOp(Ops.ENDIF, src=(k,) + g) if k.op is Ops.IF else k for k,g in groups.items()))
# Some Ops.IF nodes aren't closed by an Ops.STORE; in that case the Ops.SINK closes them.
def close_ifs(x:UOp):
  """Close one still-open Ops.IF reachable from `x` by wrapping the sources of `x` in an Ops.ENDIF.

  The consumer map of `x` is scanned in iteration order for the first Ops.IF that has
  no Ops.ENDIF among its consumers. When one is found, `x` is returned with a single
  source: a new Ops.ENDIF whose sources are that IF followed by the original sources
  of `x`. When every IF already has an ENDIF consumer (or there is no IF at all),
  None is returned.
  """
  consumer_map = x.get_consumer_map()
  for node in consumer_map:
    # only IF nodes are candidates
    if node.op is not Ops.IF: continue
    # this IF already has an ENDIF consuming it, so it is closed
    if any(user.op is Ops.ENDIF for user in consumer_map[node]): continue
    # first open IF: put it, together with the old sources, under one ENDIF
    closing = UOp(Ops.ENDIF, src=(node,) + x.src)
    return x.replace(src=(closing,))
  return None
# Rewrite rules that insert the ops closing control flow. Every rule passes the matched node to its callback as "x".
# NOTE(review): Ops.SINK appears in all three patterns, so the listed order presumably decides which callback is tried first -- confirm against PatternMatcher.
pm_control_flow_ends = PatternMatcher([
# SINK, NOOP and LOAD nodes are handed to add_endrange
(UPat((Ops.SINK, Ops.NOOP, Ops.LOAD), name="x"), add_endrange),
# SINK and ENDRANGE nodes are handed to add_endif, which wraps IF-keyed groups of sources in an ENDIF
(UPat((Ops.SINK, Ops.ENDRANGE), name="x"), add_endif),
# SINK nodes are handed to close_ifs, which closes an IF that has no ENDIF consumer yet
(UPat(Ops.SINK, name="x"), close_ifs),
])
class CFGContext:
@@ -82,7 +82,7 @@ class CFGContext:
nesting: dict[UOp, UOp] = {}
for u in sink.toposort():
deps[u] = set().union(*(deps[s] for s in u.src))
if u.op is Ops.ENDRANGE:
if u.op in (Ops.ENDRANGE, Ops.ENDIF):
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and u.src[0] in deps[x] and x not in nesting]: nesting[n] = u
if u.op is Ops.SINK:
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and x not in nesting]: nesting[n] = u
@@ -94,11 +94,12 @@ class CFGContext:
for k,v in siblings.items():
# sibling ranges that have dependencies on other siblings need to run after them
endranges = sorted([x for x in v if x.op is Ops.ENDRANGE], key=lambda x: len([y for y in v if y in deps[x]]))
zipped = zip([k.src[0]] + endranges, endranges) if k.op is Ops.ENDRANGE else zip(endranges, endranges[1:])
zipped = zip(endranges, endranges[1:]) if k.op is Ops.SINK else zip([k.src[0]] + endranges, endranges)
for x,y in zipped: self.edges[y.src[0]] = x
endifs = [x for x in v if x.op is Ops.ENDIF]
for x,y in zip(endifs, endifs[1:]): self.edges[y.src[0]] = x
# Rewrite rules that add an extra source to RANGE/IF nodes. Each callback looks the matched node up in ctx.edges and,
# when an entry exists, returns the node with that entry added as a source; it returns None when there is no entry.
# NOTE(review): ctx is presumably the CFGContext whose `edges` mapping is filled elsewhere in this file -- confirm.
pm_control_flow_starts = PatternMatcher([
# NOTE(review): the next two rules use the same pattern (RANGE or IF with a single-element src pattern) and look like the
# two sides of a diff; for a node with one source, (x.src[0], y) and x.src + (y,) build the same tuple -- confirm which rule is live.
(UPat((Ops.RANGE, Ops.IF), src=(UPat(),), name="x"), lambda ctx,x: x.replace(src=(x.src[0], y)) if (y:=ctx.edges.get(x)) is not None else None),
(UPat((Ops.RANGE, Ops.IF), src=(UPat(),), name="x"), lambda ctx,x: x.replace(src=x.src + (y,)) if (y:=ctx.edges.get(x)) is not None else None),
# an IF whose second source is a BARRIER keeps both existing sources and gets the entry appended after them
(UPat(Ops.IF, src=(UPat(), UPat(Ops.BARRIER)), name="x"), lambda ctx,x: x.replace(src=x.src + (y,)) if (y:=ctx.edges.get(x)) is not None else None),
])