linearize API

This commit is contained in:
George Hotz
2025-10-10 16:52:46 +08:00
parent 3d760bba51
commit 5d66fa479b
2 changed files with 10 additions and 10 deletions

View File

@@ -14,7 +14,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.late.control_flow import pm_control_flow_ends, pm_control_flow_starts, CFGContext, schedule
from tinygrad.codegen.late.control_flow import pm_control_flow_ends, pm_control_flow_starts, CFGContext, linearize
from tinygrad.codegen.opt.postrange import pm_postrange_opt
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
@@ -96,15 +96,14 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
# add control flow to the graph
if linearizer:
ret.append(RewriteStep(pm_control_flow_ends, name="add control flow ends"))
ret.append(RewriteStep(pm_control_flow_starts, CFGContext, name="add control flow starts", bottom_up=True))
ret.append(RewriteStep(pm_control_flow_ends, name="add control flow ends"))
ret.append(RewriteStep(pm_control_flow_starts, CFGContext, name="add control flow starts", bottom_up=True))
# return the list
return ret
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
"""
@@ -117,6 +116,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
Returns:
Linear program in UOps.
"""
lst = schedule(list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).toposort()))
lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None))
if __debug__: type_verify(lst)
return lst

View File

@@ -5,7 +5,8 @@ from itertools import groupby
from functools import reduce
import heapq
def schedule(lst:list[UOp]) -> list[UOp]:
def linearize(u:UOp) -> list[UOp]:
lst = list(u.toposort())
in_this_block = set(lst)
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree:dict[UOp, int] = {}
@@ -99,6 +100,6 @@ class CFGContext:
pm_control_flow_starts = PatternMatcher([
(UPat((Ops.RANGE, Ops.IF), src=(UPat(),), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
(UPat(Ops.IF, src=(UPat(), UPat(Ops.BARRIER)), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
# remove ranges from STORE. keep NOOP since they determine ordering
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[0:2]+tuple([x for x in s.src[2:] if x.op not in {Ops.RANGE, Ops.CONST}]))),
# optional: remove ranges from STORE. keep NOOP since they determine ordering
#(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[0:2]+tuple([x for x in s.src[2:] if x.op not in {Ops.RANGE, Ops.CONST}]))),
])