mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
we don't support multi end yet
This commit is contained in:
@@ -14,7 +14,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
||||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
from tinygrad.codegen.late.control_flow import CFGContext, pm_add_ends, pm_add_control_flow, linearize, pm_merge_ends
|
||||
from tinygrad.codegen.late.control_flow import CFGContext, pm_add_ends, pm_add_control_flow, linearize
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
@@ -79,8 +79,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_merge_ends, name="merge ends of ranges")
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow starts", bottom_up=True)
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
# return the rewritten sink
|
||||
return sink
|
||||
|
||||
@@ -96,24 +96,9 @@ class CFGContext:
|
||||
self.edges[y.src[0]] = x
|
||||
|
||||
pm_add_control_flow = PatternMatcher([
|
||||
(UPat((Ops.RANGE, Ops.IF), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
])
|
||||
|
||||
def do_merge_ends(s:UOp):
  """Fold every group of END uops that share the same first source into a single END.

  Walks `s.toposort()`, groups the END uops by `src[0]`, and for each group holding more than one END
  builds a replacement with the first member's op and arg whose sources are the shared `src[0]`
  followed by the remaining sources of every member, in toposort order.

  Returns `s` with those replacements substituted, or None when no group has more than one END.
  Asserts that every END encountered has `arg == 1`.
  """
  # NOTE: this can fail
  groups: dict[UOp, list[UOp]] = {}
  for u in s.toposort():
    if u.op is not Ops.END: continue
    assert u.arg == 1, "ends must be single ends for linearizer"
    groups.setdefault(u.src[0], []).append(u)

  replaces = {}
  for first, ends in groups.items():
    # a group always has at least one member; only groups of two or more need merging
    if len(ends) < 2: continue
    tail = tuple(y for e in ends for y in e.src[1:])
    merged = UOp(ends[0].op, src=(first,)+tail, arg=ends[0].arg)
    # every END in the group maps to the same merged uop
    for e in ends: replaces[e] = merged

  if len(replaces) == 0: return None
  return s.substitute(replaces)
|
||||
|
||||
pm_add_ends = PatternMatcher([
|
||||
# put the end on the store
|
||||
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[:2]).end(ends=s.src[2:]) if len(s.src) > 2 else None),
|
||||
@@ -122,7 +107,3 @@ pm_add_ends = PatternMatcher([
|
||||
# for rendering and linearizing, all ends must end one loop
|
||||
(UPat(Ops.END, name="e"), lambda e: e.replace(src=e.src[e.arg-1:], arg=1).end(ends=e.src[:e.arg-1]) if e.arg > 1 else None),
|
||||
])
|
||||
|
||||
# single rule: a SINK uop (bound as "s") is handed to do_merge_ends
pm_merge_ends = PatternMatcher([(UPat(Ops.SINK, name="s"), do_merge_ends)])
|
||||
Reference in New Issue
Block a user