mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
linearize API
This commit is contained in:
@@ -14,7 +14,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.late.control_flow import pm_control_flow_ends, pm_control_flow_starts, CFGContext, schedule
|
||||
from tinygrad.codegen.late.control_flow import pm_control_flow_ends, pm_control_flow_starts, CFGContext, linearize
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
@@ -96,15 +96,14 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
ret.append(RewriteStep(pm_final_rewrite, lambda _: opts.device, name="final rewrite"))
|
||||
|
||||
# add control flow to the graph
|
||||
if linearizer:
|
||||
ret.append(RewriteStep(pm_control_flow_ends, name="add control flow ends"))
|
||||
ret.append(RewriteStep(pm_control_flow_starts, CFGContext, name="add control flow starts", bottom_up=True))
|
||||
ret.append(RewriteStep(pm_control_flow_ends, name="add control flow ends"))
|
||||
ret.append(RewriteStep(pm_control_flow_starts, CFGContext, name="add control flow starts", bottom_up=True))
|
||||
|
||||
# return the list
|
||||
return ret
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True, linearizer:bool=False) -> UOp:
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize, linearizer))
|
||||
def full_rewrite_to_sink(sink:UOp, opts:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
return apply_rewrites(sink, get_rewrites_for_renderer(opts if opts is not None else Renderer(), optimize))
|
||||
|
||||
def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
|
||||
"""
|
||||
@@ -117,6 +116,6 @@ def full_rewrite(sink:UOp, opts:Renderer|None=None) -> list[UOp]:
|
||||
Returns:
|
||||
Linear program in UOps.
|
||||
"""
|
||||
lst = schedule(list(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None, linearizer=True).toposort()))
|
||||
lst = linearize(full_rewrite_to_sink(sink, opts, optimize=sink.tag is None))
|
||||
if __debug__: type_verify(lst)
|
||||
return lst
|
||||
|
||||
@@ -5,7 +5,8 @@ from itertools import groupby
|
||||
from functools import reduce
|
||||
import heapq
|
||||
|
||||
def schedule(lst:list[UOp]) -> list[UOp]:
|
||||
def linearize(u:UOp) -> list[UOp]:
|
||||
lst = list(u.toposort())
|
||||
in_this_block = set(lst)
|
||||
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree:dict[UOp, int] = {}
|
||||
@@ -99,6 +100,6 @@ class CFGContext:
|
||||
pm_control_flow_starts = PatternMatcher([
|
||||
(UPat((Ops.RANGE, Ops.IF), src=(UPat(),), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
(UPat(Ops.IF, src=(UPat(), UPat(Ops.BARRIER)), name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
||||
# remove ranges from STORE. keep NOOP since they determine ordering
|
||||
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[0:2]+tuple([x for x in s.src[2:] if x.op not in {Ops.RANGE, Ops.CONST}]))),
|
||||
# optional: remove ranges from STORE. keep NOOP since they determine ordering
|
||||
#(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[0:2]+tuple([x for x in s.src[2:] if x.op not in {Ops.RANGE, Ops.CONST}]))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user