optimize in rewrite, try 2 (#11518)

* changes

* fix test uops

* optimize in rewrite, try 2
This commit is contained in:
George Hotz
2025-08-05 15:52:53 -07:00
committed by GitHub
parent 07b0df0d86
commit b39f43c46a
3 changed files with 16 additions and 8 deletions

View File

@@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.optional import get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.opt import pm_optimize
@dataclass
class RewriteStep:
@@ -42,6 +43,10 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
# this is kernel.py
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))

View File

@@ -7,9 +7,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.opt import get_optimized_ast
from tinygrad.codegen import full_rewrite
from tinygrad.uop.spec import type_verify
# **************** Program Creation ****************
@@ -27,16 +25,13 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
if __debug__: type_verify(list(modified_ast.toposort()))
# linearize
try:
uops = full_rewrite(modified_ast, renderer)
uops = full_rewrite(ast, renderer)
except RuntimeError:
print("***** LINEARIZE FAILURE *****")
print(f"ast = {ast}")
print(f"opts = {modified_ast.arg.applied_opts}")
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"

View File

@@ -2,9 +2,10 @@
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.uop.ops import UOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
from tinygrad.renderer import Renderer
from tinygrad.uop.spec import type_verify
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
"""
@@ -27,4 +28,11 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
kb = Kernel(ast, opts=renderer)
rawbufs = bufs_from_lin(kb, allocate=False)
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
return k.get_optimized_ast()
ret = k.get_optimized_ast()
if __debug__: type_verify(list(ret.toposort()))
return ret
pm_optimize = PatternMatcher([
(UPat(Ops.SINK, name="ast"), lambda ctx,ast:
get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
])