we don't support multi end yet

This commit is contained in:
George Hotz
2025-10-22 19:59:53 +08:00
parent 174811fc0f
commit 93aa420f3b
2 changed files with 3 additions and 23 deletions

View File

@@ -14,7 +14,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
from tinygrad.codegen.late.control_flow import CFGContext, pm_add_ends, pm_add_control_flow, linearize, pm_merge_ends
from tinygrad.codegen.late.control_flow import CFGContext, pm_add_ends, pm_add_control_flow, linearize
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
@@ -79,8 +79,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
# this was the linearizer
sink = graph_rewrite(sink, pm_merge_ends, name="merge ends of ranges")
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow starts", bottom_up=True)
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# return the rewritten sink
return sink

View File

@@ -96,24 +96,9 @@ class CFGContext:
self.edges[y.src[0]] = x
pm_add_control_flow = PatternMatcher([
(UPat((Ops.RANGE, Ops.IF), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
])
def do_merge_ends(s:UOp):
# NOTE: this can fail
stacked: dict[UOp, list[UOp]] = {}
for x in s.toposort():
if x.op is Ops.END:
assert x.arg == 1, "ends must be single ends for linearizer"
stacked.setdefault(x.src[0], []).append(x)
replaces = {}
for k,v in stacked.items():
if len(v) == 1: continue
rep = UOp(v[0].op, src=tuple([k] + [y for x in v for y in x.src[1:]]), arg=v[0].arg)
for x in v: replaces[x] = rep
if not len(replaces): return None
return s.substitute(replaces)
pm_add_ends = PatternMatcher([
# put the end on the store
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[:2]).end(ends=s.src[2:]) if len(s.src) > 2 else None),
@@ -122,7 +107,3 @@ pm_add_ends = PatternMatcher([
# for renderering and linearizing, all ends must end one loop
(UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
])
pm_merge_ends = PatternMatcher([
(UPat(Ops.SINK, name="s"), do_merge_ends),
])