mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
terrible
This commit is contained in:
@@ -54,7 +54,15 @@ def add_endrange(x:UOp):
|
||||
return x.replace(src=tuple(src))
|
||||
|
||||
def add_endif(x:UOp):
|
||||
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: k.src[2] if len(k.src) >= 3 and k.src[2].op is Ops.IF else k)}
|
||||
# HACK: sometimes the if is in the load, either need a better way of inserting the endif or it should be added when the if is created
|
||||
def _find_if(x:UOp):
|
||||
if x.op is not Ops.STORE: return x
|
||||
if len(x.src) > 2 and x.src[2].op is Ops.IF: return x.src[2]
|
||||
ret = x.src[1]
|
||||
while ret.src and ret.op is not Ops.LOAD: ret = ret.src[-1]
|
||||
if len(ret.src) < 2 or ret.src[1].op is not Ops.IF: return x
|
||||
return ret.src[1]
|
||||
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: _find_if(k))}
|
||||
if not any(k.op is Ops.IF for k in groups): return None
|
||||
return x.replace(src=tuple(UOp(Ops.ENDIF, src=(k,) + g) if k.op is Ops.IF else k for k,g in groups.items()))
|
||||
|
||||
@@ -73,12 +81,12 @@ class CFGContext:
|
||||
deps: dict[UOp, set[UOp]] = {}
|
||||
nesting: dict[UOp, UOp] = {}
|
||||
for u in sink.toposort():
|
||||
deps[u] = {u} if u.op in (Ops.RANGE, Ops.IF) else set().union(*(deps[s] for s in u.src))
|
||||
deps[u] = set().union(*(deps[s] for s in u.src))
|
||||
if u.op is Ops.ENDRANGE:
|
||||
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and u.src[0] in deps[x] and x not in nesting]: nesting[n] = u
|
||||
if u.op is Ops.SINK:
|
||||
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and x not in nesting]: nesting[n] = u
|
||||
if u.op in (Ops.ENDRANGE, Ops.ENDIF): deps[u] |= {u}
|
||||
if u.op in (Ops.RANGE, Ops.ENDRANGE, Ops.IF, Ops.ENDIF): deps[u] |= {u}
|
||||
|
||||
self.edges: dict[UOp, UOp] = {}
|
||||
siblings: dict[UOp, list[UOp]] = {}
|
||||
|
||||
Reference in New Issue
Block a user