This commit is contained in:
ttomsa
2025-10-08 00:35:53 +01:00
parent 4687d000a7
commit 332293cfa7

View File

@@ -54,7 +54,15 @@ def add_endrange(x:UOp):
return x.replace(src=tuple(src))
def add_endif(x:UOp):
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: k.src[2] if len(k.src) >= 3 and k.src[2].op is Ops.IF else k)}
# HACK: sometimes the if is in the load, either need a better way of inserting the endif or it should be added when the if is created
def _find_if(x:UOp):
if x.op is not Ops.STORE: return x
if len(x.src) > 2 and x.src[2].op is Ops.IF: return x.src[2]
ret = x.src[1]
while ret.src and ret.op is not Ops.LOAD: ret = ret.src[-1]
if len(ret.src) < 2 or ret.src[1].op is not Ops.IF: return x
return ret.src[1]
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: _find_if(k))}
if not any(k.op is Ops.IF for k in groups): return None
return x.replace(src=tuple(UOp(Ops.ENDIF, src=(k,) + g) if k.op is Ops.IF else k for k,g in groups.items()))
@@ -73,12 +81,12 @@ class CFGContext:
deps: dict[UOp, set[UOp]] = {}
nesting: dict[UOp, UOp] = {}
for u in sink.toposort():
deps[u] = {u} if u.op in (Ops.RANGE, Ops.IF) else set().union(*(deps[s] for s in u.src))
deps[u] = set().union(*(deps[s] for s in u.src))
if u.op is Ops.ENDRANGE:
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and u.src[0] in deps[x] and x not in nesting]: nesting[n] = u
if u.op is Ops.SINK:
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and x not in nesting]: nesting[n] = u
if u.op in (Ops.ENDRANGE, Ops.ENDIF): deps[u] |= {u}
if u.op in (Ops.RANGE, Ops.ENDRANGE, Ops.IF, Ops.ENDIF): deps[u] |= {u}
self.edges: dict[UOp, UOp] = {}
siblings: dict[UOp, list[UOp]] = {}