mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
gpu work now?
This commit is contained in:
@@ -54,21 +54,21 @@ def add_endrange(x:UOp):
|
||||
return x.replace(src=tuple(src))
|
||||
|
||||
def add_endif(x:UOp):
|
||||
# HACK: sometimes the if is in the load, either need a better way of inserting the endif or it should be added when the if is created
|
||||
def _find_if(x:UOp):
|
||||
if x.op is not Ops.STORE: return x
|
||||
if len(x.src) > 2 and x.src[2].op is Ops.IF: return x.src[2]
|
||||
ret = x.src[1]
|
||||
while ret.src and ret.op is not Ops.LOAD: ret = ret.src[-1]
|
||||
if len(ret.src) < 2 or ret.src[1].op is not Ops.IF: return x
|
||||
return ret.src[1]
|
||||
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: _find_if(k))}
|
||||
groups = {k: tuple(g) for k,g in groupby(x.src, key=lambda k: k.src[2] if len(k.src) >= 3 and k.src[2].op is Ops.IF else k)}
|
||||
if not any(k.op is Ops.IF for k in groups): return None
|
||||
return x.replace(src=tuple(UOp(Ops.ENDIF, src=(k,) + g) if k.op is Ops.IF else k for k,g in groups.items()))
|
||||
|
||||
# some Ops.IF aren't closed by an Ops.STORE, in that case the Ops.SINK closes it
|
||||
def close_ifs(x:UOp):
|
||||
consumers = x.get_consumer_map()
|
||||
if (y:=next((s for s in consumers if s.op is Ops.IF and all(n.op is not Ops.ENDIF for n in consumers[s])), None)) is not None:
|
||||
return x.replace(src=(UOp(Ops.ENDIF, src=(y,) + x.src),))
|
||||
return None
|
||||
|
||||
pm_control_flow_ends = PatternMatcher([
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.LOAD), name="x"), add_endrange),
|
||||
(UPat((Ops.SINK, Ops.ENDRANGE), name="x"), add_endif),
|
||||
(UPat(Ops.SINK, name="x"), close_ifs),
|
||||
])
|
||||
|
||||
class CFGContext:
|
||||
@@ -82,7 +82,7 @@ class CFGContext:
|
||||
nesting: dict[UOp, UOp] = {}
|
||||
for u in sink.toposort():
|
||||
deps[u] = set().union(*(deps[s] for s in u.src))
|
||||
if u.op is Ops.ENDRANGE:
|
||||
if u.op in (Ops.ENDRANGE, Ops.ENDIF):
|
||||
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and u.src[0] in deps[x] and x not in nesting]: nesting[n] = u
|
||||
if u.op is Ops.SINK:
|
||||
for n in [x for x in deps[u] if x.op in (Ops.ENDRANGE, Ops.ENDIF) and x not in nesting]: nesting[n] = u
|
||||
@@ -94,11 +94,12 @@ class CFGContext:
|
||||
for k,v in siblings.items():
|
||||
# sibling ranges that have dependencies on other siblings need to run after them
|
||||
endranges = sorted([x for x in v if x.op is Ops.ENDRANGE], key=lambda x: len([y for y in v if y in deps[x]]))
|
||||
zipped = zip([k.src[0]] + endranges, endranges) if k.op is Ops.ENDRANGE else zip(endranges, endranges[1:])
|
||||
zipped = zip(endranges, endranges[1:]) if k.op is Ops.SINK else zip([k.src[0]] + endranges, endranges)
|
||||
for x,y in zipped: self.edges[y.src[0]] = x
|
||||
endifs = [x for x in v if x.op is Ops.ENDIF]
|
||||
for x,y in zip(endifs, endifs[1:]): self.edges[y.src[0]] = x
|
||||
|
||||
pm_control_flow_starts = PatternMatcher([
|
||||
(UPat((Ops.RANGE, Ops.IF), src=(UPat(),), name="x"), lambda ctx,x: x.replace(src=(x.src[0], y)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
(UPat((Ops.RANGE, Ops.IF), src=(UPat(),), name="x"), lambda ctx,x: x.replace(src=x.src + (y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
(UPat(Ops.IF, src=(UPat(), UPat(Ops.BARRIER)), name="x"), lambda ctx,x: x.replace(src=x.src + (y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user