mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
delete the lowerer (#12526)
This commit is contained in:
@@ -1,40 +0,0 @@
|
||||
import time
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
from tinygrad import Device
|
||||
from tinygrad.codegen.lowerer import pm_lowerer, get_index
|
||||
from tinygrad.uop.ops import graph_rewrite
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
if __name__ == "__main__":
  # For every AST returned by load_worlds(), compare the hand-coded optimizations chosen
  # for a Kernel with those chosen for a Scheduler built from the lowered AST.
  renderer = Device.default.renderer
  ast_strs = load_worlds()
  # N picks a single AST by index; the default of -1 keeps the whole list
  n = getenv("N", -1)
  if n != -1:
    ast_strs = ast_strs[n:n+1]

  matches = 0
  for idx, ast_str in enumerate(ast_strs):
    ast = ast_str_to_ast(ast_str)

    # timed region 1: Kernel construction + optimization search
    t0 = time.perf_counter()
    kernel = Kernel(ast, renderer)
    kernel_opts = hand_coded_optimizations(kernel)
    et_kernel = time.perf_counter() - t0

    # the lowering rewrite is deliberately outside both timed regions
    lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast), bottom_up=True)

    # timed region 2: Scheduler construction, its two preparation passes, and optimization search
    t0 = time.perf_counter()
    scheduler = Scheduler(lowered, renderer)
    scheduler.convert_loop_to_global()
    scheduler.simplify_merge_adjacent()
    sched_opts = hand_coded_optimizations(scheduler)
    et_sched = time.perf_counter() - t0

    if kernel_opts != sched_opts:
      # mismatch: show the shapes before/after applying each set of opts, then the opts themselves
      print(f"******* {idx:6d}")
      print("Kernel: ", kernel.colored_shape(), "->", kernel.apply_opts(kernel_opts).colored_shape())
      print("Scheduler: ", scheduler.colored_shape(), "->", scheduler.apply_opts(sched_opts).colored_shape())
      print(kernel_opts)
      print(sched_opts)
    else:
      # match: report the running match percentage and the Kernel/Scheduler time ratio
      matches += 1
      print(f"******* {idx:6d} MATCH {matches/(idx+1)*100:.2f}% -- {et_kernel/et_sched:4.2f}x speedup")
|
||||
@@ -7,7 +7,6 @@ from tinygrad.uop.spec import type_verify
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.lowerer import pm_lowerer, get_index
|
||||
from tinygrad.codegen.quantize import pm_quant
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic
|
||||
@@ -51,7 +50,6 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
|
||||
# lowerer first
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
|
||||
# split ranges
|
||||
if _RANGEIFY:
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
|
||||
|
||||
# ***** indexing *****
|
||||
|
||||
@dataclass
class IndexContext:
  """Context object handed to the pm_lowerer rewrite rules."""
  # axis types; paired position-by-position with shape dimensions in shape_to_idx
  axis_types: tuple[AxisType, ...]
  # index UOps, looked up by axis number in the rewrite rules
  idxs: list[UOp]
  # first range id passed to shape_to_idx; subblock advances it by 1000
  start: int = 0
|
||||
|
||||
def shape_to_idx(s, axis_types, start=0):
  """Return one range UOp per (dimension, axis type) pair.

  The pair at position ``i`` becomes ``UOp.range(sint_to_uop(dim), start+i, axis_type)``.
  Pairing is done with ``zip``, so it stops at the shorter of ``s`` and ``axis_types``.
  """
  ranges = []
  for offset, (dim, axis_type) in enumerate(zip(s, axis_types)):
    ranges.append(UOp.range(sint_to_uop(dim), start+offset, axis_type))
  return ranges
|
||||
|
||||
def get_index(ast:UOp) -> IndexContext:
  """Build the initial IndexContext for ``ast``.

  Axis types come from ``ast.arg`` when it is a KernelInfo, otherwise an empty tuple is used.
  The returned context has no indices yet and starts at 0.
  """
  if isinstance(ast.arg, KernelInfo):
    axis_types = ast.arg.axis_types
  else:
    axis_types = ()
  # NOTE: the fallback below (deriving axis types from shape vs full_shape) is disabled
  #if len(ast.full_shape) != len(axis_types) and ast.st is not None:
  #  axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
  return IndexContext(axis_types=axis_types, idxs=[], start=0)
|
||||
|
||||
# ***** lowering (given index) *****
|
||||
|
||||
def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
  """Rewrite ``src`` with pm_lowerer under a child context that uses ``full_new_idx`` as its indices.

  Side effect: ``ctx.start`` is advanced by 1000, and the child context starts at that same value.
  Returns the result of the bottom-up graph_rewrite.
  """
  # NOTE(review): the 1000 stride presumably keeps range ids of nested blocks apart -- confirm
  child_start = ctx.start + 1000
  ctx.start = child_start
  child_ctx = IndexContext(ctx.axis_types, full_new_idx, child_start)
  return graph_rewrite(src, pm_lowerer, child_ctx, name="subblock", bottom_up=True)
|
||||
|
||||
def fixup_wmma(ctx:IndexContext, x:UOp):
  """Rewrite rule for WMMA: lower the sources in a subblock and remap the axis numbers in the last two args.

  Returns None when ``x`` already carries a tag (so the rule fires once per node); otherwise returns
  a copy of ``x`` with rewritten sources, remapped ``arg[-2]``/``arg[-1]``, and ``tag=1``.
  """
  if x.tag is not None:
    return None

  # fresh ranges built from the first source's shape; must happen before subblock() bumps ctx.start
  fresh_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)

  # copy the current indices, swapping in a fresh range for each axis listed in x.arg[-1]
  full_new_idx = list(ctx.idxs)
  for axis in x.arg[-1]:
    full_new_idx[axis] = fresh_idxs[axis]

  # rewrite all sources together under the new indices
  lowered_sink = subblock(ctx, full_new_idx, UOp.sink(*x.src))
  srcs = lowered_sink.src

  def axis_id(axis):
    # first element of the arg of the index UOp standing in for this axis
    return full_new_idx[axis].arg[0]

  # NOTE: this assumes these are expanded. which now shouldn't change anything
  remapped_pairs = tuple(tuple((axis_id(a), sz) for a,sz in group) for group in x.arg[-2])
  remapped_axes = tuple(axis_id(a) for a in x.arg[-1])
  return x.replace(src=srcs, arg=x.arg[:-2]+(remapped_pairs, remapped_axes), tag=1)
|
||||
|
||||
def _fixup_axis_arg(ctx, x):
  """Rewrite rule for CONTRACT/UNROLL: remap each (axis, size) pair in ``x.arg`` through ``ctx.idxs`` and tag the node."""
  # a tag means this node was already rewritten; returning None leaves it alone
  if x.tag is not None:
    return None
  remapped = tuple([(ctx.idxs[a].arg[0], sz) for a,sz in x.arg])
  return x.replace(tag=1, arg=remapped)

# rewrite rules of the lowerer; the context passed to each rule is an IndexContext
pm_lowerer = PatternMatcher([
  (UPat(Ops.WMMA, name="x"), fixup_wmma),

  # axis fixups for WMMA
  (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"), _fixup_axis_arg),
])
|
||||
Reference in New Issue
Block a user