diff --git a/extra/test_hcopt.py b/extra/test_hcopt.py
deleted file mode 100644
index 36978bf831..0000000000
--- a/extra/test_hcopt.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import time
-from extra.optimization.helpers import load_worlds, ast_str_to_ast
-from tinygrad import Device
-from tinygrad.codegen.lowerer import pm_lowerer, get_index
-from tinygrad.uop.ops import graph_rewrite
-from tinygrad.codegen.opt.kernel import Kernel
-from tinygrad.codegen.opt.postrange import Scheduler
-from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
-from tinygrad.helpers import getenv
-
-if __name__ == "__main__":
-  renderer = Device.default.renderer
-  ast_strs = load_worlds()
-  if (n:=getenv("N", -1)) != -1: ast_strs = ast_strs[n:n+1]
-  good = 0
-  for i, ast_str in enumerate(ast_strs):
-    ast = ast_str_to_ast(ast_str)
-
-    st = time.perf_counter()
-    lin = Kernel(ast, renderer)
-    opt1 = hand_coded_optimizations(lin)
-    et_lin = time.perf_counter() - st
-
-    lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast), bottom_up=True)
-    st = time.perf_counter()
-    sch = Scheduler(lowered, renderer)
-    sch.convert_loop_to_global()
-    sch.simplify_merge_adjacent()
-    opt2 = hand_coded_optimizations(sch)
-    et_sch = time.perf_counter() - st
-
-    if opt1 != opt2:
-      print(f"******* {i:6d}")
-      print("Kernel: ", lin.colored_shape(), "->", lin.apply_opts(opt1).colored_shape())
-      print("Scheduler: ", sch.colored_shape(), "->", sch.apply_opts(opt2).colored_shape())
-      print(opt1)
-      print(opt2)
-    else:
-      good += 1
-      print(f"******* {i:6d} MATCH {good/(i+1)*100:.2f}% -- {et_lin/et_sch:4.2f}x speedup")
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index 97957059f6..32f6ce9490 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -7,7 +7,6 @@ from tinygrad.uop.spec import type_verify
 from tinygrad.renderer import Renderer
 
 # import all pattern matchers here
-from tinygrad.codegen.lowerer import pm_lowerer, get_index
 from tinygrad.codegen.quantize import pm_quant
 from tinygrad.codegen.gpudims import pm_add_gpudims
 from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic
@@ -51,7 +50,6 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
 
   # lowerer first
   if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
-  ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
 
   # split ranges
   if _RANGEIFY:
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
deleted file mode 100644
index 06794d88c9..0000000000
--- a/tinygrad/codegen/lowerer.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# the job of the lowerer is to do indexing
-from dataclasses import dataclass
-from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
-
-# ***** indexing *****
-
-@dataclass
-class IndexContext:
-  axis_types: tuple[AxisType, ...]
-  idxs: list[UOp]
-  start: int = 0
-
-def shape_to_idx(s, axis_types, start=0):
-  return [UOp.range(sint_to_uop(s), start+i, at) for i, (s, at) in enumerate(zip(s, axis_types))]
-
-def get_index(ast:UOp) -> IndexContext:
-  axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
-  #if len(ast.full_shape) != len(axis_types) and ast.st is not None:
-  #  axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
-  return IndexContext(axis_types, [], 0)
-
-# ***** lowering (given index) *****
-
-def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
-  lc = IndexContext(ctx.axis_types, full_new_idx, ctx.start+1000)
-  ctx.start = lc.start
-  return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)
-
-def fixup_wmma(ctx:IndexContext, x:UOp):
-  if x.tag is not None: return None
-  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
-  full_new_idx = list(ctx.idxs)
-  for a in x.arg[-1]: full_new_idx[a] = new_idxs[a]
-
-  srcs = subblock(ctx, full_new_idx, UOp.sink(*x.src)).src
-
-  # NOTE: this assumes these are expanded. which now shouldn't change anything
-  new_x_arg_m2 = tuple([tuple([(full_new_idx[a].arg[0], sz) for a,sz in v]) for v in x.arg[-2]])
-  new_x_arg_m1 = tuple([full_new_idx[a].arg[0] for a in x.arg[-1]])
-  return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)
-
-pm_lowerer = PatternMatcher([
-  (UPat(Ops.WMMA, name="x"), fixup_wmma),
-
-  # axis fixups for WMMA
-  (UPat((Ops.CONTRACT, Ops.UNROLL), name="x"),
-   lambda ctx,x: x.replace(tag=1, arg=tuple([(ctx.idxs[a].arg[0], sz) for a,sz in x.arg])) if x.tag is None else None),
-])