delete the lowerer (#12526)

This commit is contained in:
George Hotz
2025-10-08 21:58:18 +08:00
committed by GitHub
parent 0774575442
commit 2653147cb7
3 changed files with 0 additions and 90 deletions

View File

@@ -1,40 +0,0 @@
import time
from extra.optimization.helpers import load_worlds, ast_str_to_ast
from tinygrad import Device
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.uop.ops import graph_rewrite
from tinygrad.codegen.opt.kernel import Kernel
from tinygrad.codegen.opt.postrange import Scheduler
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
from tinygrad.helpers import getenv
if __name__ == "__main__":
  # Compare hand_coded_optimizations results from the Kernel path against the
  # Scheduler (lowered) path over a corpus of saved ASTs, reporting mismatches
  # and the relative wall-clock speedup of the Scheduler path.
  renderer = Device.default.renderer
  ast_strs = load_worlds()
  n = getenv("N", -1)
  if n != -1: ast_strs = ast_strs[n:n+1]  # N=<i> restricts the run to a single AST
  good = 0
  for i, ast_str in enumerate(ast_strs):
    ast = ast_str_to_ast(ast_str)
    # path 1: Kernel-based optimization (timed)
    st = time.perf_counter()
    kern = Kernel(ast, renderer)
    opt1 = hand_coded_optimizations(kern)
    et_lin = time.perf_counter() - st
    # lowering itself is deliberately outside both timed regions
    lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast), bottom_up=True)
    # path 2: Scheduler-based optimization (timed)
    st = time.perf_counter()
    sched = Scheduler(lowered, renderer)
    sched.convert_loop_to_global()
    sched.simplify_merge_adjacent()
    opt2 = hand_coded_optimizations(sched)
    et_sch = time.perf_counter() - st
    if opt1 == opt2:
      good += 1
      print(f"******* {i:6d} MATCH {good/(i+1)*100:.2f}% -- {et_lin/et_sch:4.2f}x speedup")
    else:
      print(f"******* {i:6d}")
      print("Kernel:     ", kern.colored_shape(), "->", kern.apply_opts(opt1).colored_shape())
      print("Scheduler:  ", sched.colored_shape(), "->", sched.apply_opts(opt2).colored_shape())
      print(opt1)
      print(opt2)

View File

@@ -7,7 +7,6 @@ from tinygrad.uop.spec import type_verify
from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_lowerer, get_index
from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic
@@ -51,7 +50,6 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
# lowerer first
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
# split ranges
if _RANGEIFY:

View File

@@ -1,48 +0,0 @@
# the job of the lowerer is to do indexing
from dataclasses import dataclass
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
# ***** indexing *****
@dataclass
class IndexContext:
  """Indexing state threaded through the lowerer's graph rewrites."""
  axis_types: tuple[AxisType, ...]  # one AxisType per axis of the kernel (empty if no KernelInfo)
  idxs: list[UOp]  # index UOps for each axis; presumably RANGE UOps from shape_to_idx — confirm at call sites
  start: int = 0  # base id for newly created ranges; bumped by subblock to keep ids unique
def shape_to_idx(s, axis_types, start=0):
  """Create one RANGE UOp per axis of shape *s*, numbered start, start+1, ...

  Fix: the original comprehension reused the name `s` for the per-axis size,
  shadowing the shape parameter — it only worked because zip() evaluates its
  argument before the rebinding. Renamed the loop variable for clarity.
  """
  out = []
  for i, (dim, at) in enumerate(zip(s, axis_types)):
    out.append(UOp.range(sint_to_uop(dim), start+i, at))
  return out
def get_index(ast:UOp) -> IndexContext:
  """Build the initial IndexContext for a kernel ast.

  Axis types are taken from the ast's KernelInfo arg when present, otherwise
  the context starts with no axis types; idxs starts empty and start at 0.
  """
  if isinstance(ast.arg, KernelInfo):
    axis_types = ast.arg.axis_types
  else:
    axis_types = ()
  # NOTE: axis-type reconstruction from shapes was disabled in the original:
  #if len(ast.full_shape) != len(axis_types) and ast.st is not None:
  #  axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
  return IndexContext(axis_types, [], 0)
# ***** lowering (given index) *****
def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
  """Run the lowerer over *src* in a child context using *full_new_idx* as the indices."""
  child = IndexContext(ctx.axis_types, full_new_idx, ctx.start+1000)
  # propagate the bumped range-id base back to the parent so ids stay unique across subblocks
  ctx.start = child.start
  return graph_rewrite(src, pm_lowerer, child, name="subblock", bottom_up=True)
def fixup_wmma(ctx:IndexContext, x:UOp):
  """Re-lower a WMMA's sources in a subblock and renumber the axes in its arg.

  Returns None (no rewrite) when the node is already tagged, i.e. processed.
  """
  if x.tag is not None: return None
  # fresh ranges covering every axis of the first source's shape
  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
  full_new_idx = list(ctx.idxs)
  # only the axes listed in x.arg[-1] are replaced with the fresh ranges;
  # presumably x.arg[-1] holds the WMMA's reduce/contract axis numbers — confirm against WMMA construction
  for a in x.arg[-1]: full_new_idx[a] = new_idxs[a]
  # lower all sources together under the swapped-in indices, then take them back apart
  srcs = subblock(ctx, full_new_idx, UOp.sink(*x.src)).src
  # NOTE: this assumes these are expanded. which now shouldn't change anything
  # rebuild the last two arg entries with the new range ids (arg[0] of each index UOp)
  new_x_arg_m2 = tuple([tuple([(full_new_idx[a].arg[0], sz) for a,sz in v]) for v in x.arg[-2]])
  new_x_arg_m1 = tuple([full_new_idx[a].arg[0] for a in x.arg[-1]])
  # tag=1 marks the node as done so the pattern does not match it again
  return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)
# Lowering pattern matcher: rewrites WMMA nodes (and their CONTRACT/UNROLL
# companions) to use the context's index ranges. Both rules use tag as a
# "already processed" marker to avoid rewriting a node twice.
pm_lowerer = PatternMatcher([
  (UPat(Ops.WMMA, name="x"), fixup_wmma),
  # axis fixups for WMMA
  (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
   # remap each (axis, size) pair in arg to (range id, size) using ctx.idxs
   lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0], sz) for a,sz in x.arg])) if x.tag is None else None),
])