diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 24456a2a4d..9ee02e682b 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin ReduceContext, correct_load_store, pm_render from tinygrad.codegen.optional import get_late_rewrite_patterns from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext +from tinygrad.opt import pm_optimize @dataclass class RewriteStep: @@ -42,6 +43,10 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]: # ** lowerer (rewrite_shapetracker_with_index) ** ret: list[RewriteStep] = [] + + # kernel.py optimization (get_optimized_ast via pm_optimize) runs here as the first rewrite step, before quantize and the lowerer + ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast")) + if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize")) ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True)) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index f4cd06c5bf..388f7d33e8 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -7,9 +7,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates from tinygrad.engine.schedule import ScheduleItem -from tinygrad.opt import get_optimized_ast from tinygrad.codegen import full_rewrite -from tinygrad.uop.spec import type_verify # **************** Program Creation **************** @@ -27,16 +25,13 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec: """ if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST") - modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or 
ast.arg.opts_to_apply is not None else ast - if __debug__: type_verify(list(modified_ast.toposort())) # linearize try: - uops = full_rewrite(modified_ast, renderer) + uops = full_rewrite(ast, renderer) except RuntimeError: print("***** LINEARIZE FAILURE *****") print(f"ast = {ast}") - print(f"opts = {modified_ast.arg.applied_opts}") raise assert uops[-1].op is Ops.SINK, "last uop must be sink" diff --git a/tinygrad/opt/__init__.py b/tinygrad/opt/__init__.py index 934b9f0749..6128b21a5b 100644 --- a/tinygrad/opt/__init__.py +++ b/tinygrad/opt/__init__.py @@ -2,9 +2,10 @@ from tinygrad.opt.kernel import Kernel from tinygrad.opt.heuristic import hand_coded_optimizations -from tinygrad.uop.ops import UOp +from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv from tinygrad.renderer import Renderer +from tinygrad.uop.spec import type_verify def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp: """ @@ -27,4 +28,11 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp: kb = Kernel(ast, opts=renderer) rawbufs = bufs_from_lin(kb, allocate=False) k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) - return k.get_optimized_ast() + ret = k.get_optimized_ast() + if __debug__: type_verify(list(ret.toposort())) + return ret + +pm_optimize = PatternMatcher([ + (UPat(Ops.SINK, name="ast"), lambda ctx,ast: + get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None), +])