mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 13:28:06 -05:00)
optimize in rewrite, try 2 (#11518)
* changes
* fix test uops
* optimize in rewrite, try 2
@@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
   ReduceContext, correct_load_store, pm_render
 from tinygrad.codegen.optional import get_late_rewrite_patterns
 from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
+from tinygrad.opt import pm_optimize

 @dataclass
 class RewriteStep:
@@ -42,6 +43,10 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
 def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
   # ** lowerer (rewrite_shapetracker_with_index) **
   ret: list[RewriteStep] = []
+
+  # this is kernel.py
+  ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
+
   if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
   ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
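For readers skimming the diff: a RewriteStep pairs a PatternMatcher with an optional ctx builder, a name, and a bottom_up flag, and full_rewrite applies the steps in order, so the new "optimize ast" step now runs before quantize and the lowerer. The sketch below is not part of the diff; it is a minimal illustration in which the RewriteStep field layout and the graph_rewrite keyword arguments are inferred from the constructor calls above, not confirmed by this commit.

from dataclasses import dataclass
from typing import Any, Callable
from tinygrad.uop.ops import UOp, PatternMatcher, graph_rewrite

@dataclass(frozen=True)
class RewriteStepSketch:
  # hypothetical stand-in for codegen's RewriteStep
  pm: PatternMatcher                     # patterns to apply in this step
  ctx: Callable[[UOp], Any]|None = None  # builds the ctx handed to each pattern (e.g. lambda _: opts)
  name: str|None = None                  # label, e.g. "optimize ast"
  bottom_up: bool = False                # rewrite direction, e.g. True for the lowerer

def apply_steps(sink: UOp, steps: list[RewriteStepSketch]) -> UOp:
  # roughly what the step list amounts to: each step is one graph_rewrite pass
  for step in steps:
    ctx = step.ctx(sink) if step.ctx is not None else None
    sink = graph_rewrite(sink, step.pm, ctx=ctx, bottom_up=step.bottom_up, name=step.name)
  return sink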
@@ -7,9 +7,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.engine.schedule import ScheduleItem
-from tinygrad.opt import get_optimized_ast
 from tinygrad.codegen import full_rewrite
-from tinygrad.uop.spec import type_verify

 # **************** Program Creation ****************
@@ -27,16 +25,13 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
   """

   if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
-  modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
-  if __debug__: type_verify(list(modified_ast.toposort()))

   # linearize
   try:
-    uops = full_rewrite(modified_ast, renderer)
+    uops = full_rewrite(ast, renderer)
   except RuntimeError:
     print("***** LINEARIZE FAILURE *****")
     print(f"ast = {ast}")
-    print(f"opts = {modified_ast.arg.applied_opts}")
     raise
   assert uops[-1].op is Ops.SINK, "last uop must be sink"
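The net effect of this hunk is that get_program no longer pre-optimizes or type-verifies the AST itself; the raw SINK goes straight into full_rewrite, and the "optimize ast" step added above decides whether optimization applies. A condensed sketch of the resulting flow, not part of the diff, keeping only the lines visible here and eliding the rest of get_program:

from tinygrad.uop.ops import Ops, UOp
from tinygrad.renderer import Renderer
from tinygrad.codegen import full_rewrite

def linearize_sketch(ast: UOp, renderer: Renderer) -> list[UOp]:
  # optimization and type_verify now happen inside full_rewrite via pm_optimize
  try:
    uops = full_rewrite(ast, renderer)
  except RuntimeError:
    print("***** LINEARIZE FAILURE *****")
    print(f"ast = {ast}")
    raise
  assert uops[-1].op is Ops.SINK, "last uop must be sink"
  return uops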
@@ -2,9 +2,10 @@

 from tinygrad.opt.kernel import Kernel
 from tinygrad.opt.heuristic import hand_coded_optimizations
-from tinygrad.uop.ops import UOp
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
 from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
 from tinygrad.renderer import Renderer
+from tinygrad.uop.spec import type_verify

 def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
   """
@@ -27,4 +28,11 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
     kb = Kernel(ast, opts=renderer)
     rawbufs = bufs_from_lin(kb, allocate=False)
     k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-  return k.get_optimized_ast()
+  ret = k.get_optimized_ast()
+  if __debug__: type_verify(list(ret.toposort()))
+  return ret
+
+pm_optimize = PatternMatcher([
+  (UPat(Ops.SINK, name="ast"), lambda ctx,ast:
+    get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
+])
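The pattern above fires only for a SINK whose arg is unset or still carries opts_to_apply and whose first source has a ShapeTracker; returning None from the lambda tells the PatternMatcher to leave the node unchanged, which is how already-optimized ASTs pass through untouched. A small usage sketch, not part of the diff: the ctx here is the Renderer, matching how codegen registers the step with ctx=lambda _: opts, while the graph_rewrite import path and its ctx keyword are assumptions.

from tinygrad.uop.ops import UOp, graph_rewrite
from tinygrad.renderer import Renderer
from tinygrad.opt import pm_optimize

def optimize_sink(ast: UOp, renderer: Renderer) -> UOp:
  # run just the "optimize ast" pass: get_optimized_ast is called when the
  # guard matches, otherwise the SINK is returned as-is
  return graph_rewrite(ast, pm_optimize, ctx=renderer, name="optimize ast")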