opt transforms the ast into an optimized ast (#10900)

* opt transforms the ast into an optimized ast

* fix get_kernel order and to_function_name

* function_name property

* update docs

* copy from kernel.py

* improve docs

* ci didn't trigger?
George Hotz
2025-06-22 09:41:26 -07:00
committed by GitHub
parent ffddf165f8
commit b09c47366f
8 changed files with 96 additions and 24 deletions

docs/developer.md

@@ -1,6 +1,8 @@
 # tinygrad directory layout
-Listed in order of how they are processed
+This explains the flow of a big graph down to programs.
+Directories are listed in order of how they are processed.
 ---
@@ -20,7 +22,11 @@ Group UOps into kernels.
 Transforms the ast into an optimized ast. This is where BEAM search and heuristics live.
-When finished, this will just have a function that takes in the ast and returns the optimized ast.
+::: tinygrad.opt.get_optimized_ast
+    options:
+      members: false
+      show_labels: false
+      show_source: false
 ---
@@ -38,10 +44,23 @@ Transform the optimized ast into a linearized list of UOps.
 ## tinygrad/renderer
-Transform the linearized list of UOps into a program.
+Transform the linearized list of UOps into a program, represented as a string.
+::: tinygrad.renderer.Renderer
+    options:
+      members:
+        - render
+      show_labels: false
+      show_source: false
 ---
 ## tinygrad/engine
 Abstracted high level interface to the runtimes.
+::: tinygrad.engine.realize.get_program
+    options:
+      members: false
+      show_labels: false
+      show_source: false
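The three documented stages can be driven by hand to see the handoffs. A minimal sketch against the APIs named in this diff (`get_optimized_ast`, `full_rewrite`, `Renderer.render`); obtaining the AST via `Tensor.schedule()` is an assumption about the surrounding public API, not part of this change:

```python
# Sketch: run the documented pipeline stages manually.
# The AST comes from scheduling a small Tensor computation (assumed API).
from tinygrad import Tensor, Device
from tinygrad.opt import get_optimized_ast
from tinygrad.codegen import full_rewrite

ast = (Tensor.ones(4, 4).contiguous() + 1).schedule()[-1].ast  # Ops.SINK rooted AST
renderer = Device[Device.DEFAULT].renderer

optimized = get_optimized_ast(ast, renderer)  # tinygrad/opt: heuristics or BEAM
uops = full_rewrite(optimized, renderer)      # tinygrad/codegen: linearized list of UOps
src = renderer.render(uops)                   # tinygrad/renderer: program as a string
print(src)
```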

tinygrad/engine/realize.py

@@ -1,29 +1,51 @@
 from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace, field
-from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
-from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
+from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.opt.kernel import Kernel
-from tinygrad.opt.heuristic import hand_coded_optimizations
 from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.opt import get_optimized_ast
+from tinygrad.codegen import full_rewrite
+from tinygrad.uop.spec import type_verify

 # **************** Program Creation ****************

-logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
+@track_rewrites(name=lambda _ast,_renderer,ret:ret.name)
 def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
-  k = Kernel(ast, opts=renderer)
-  if not NOOPT:
-    if not k.apply_tensor_cores(getenv("TC", 1)): k.apply_opts(hand_coded_optimizations(k))
-    if BEAM >= 1:
-      from tinygrad.opt.search import beam_search, bufs_from_lin
-      kb = Kernel(ast, opts=renderer)
-      rawbufs = bufs_from_lin(kb, allocate=False)
-      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-      if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
-  return k.to_program()
+  """
+  Transform an AST into a ProgramSpec. May trigger BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The ProgramSpec of the program.
+  """
+  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
+  modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None else ast
+  if __debug__: type_verify(list(modified_ast.toposort()))
+
+  # linearize
+  try:
+    uops = full_rewrite(modified_ast, renderer)
+  except RuntimeError:
+    print("***** LINEARIZE FAILURE *****")
+    print(f"ast = {ast}")
+    print(f"opts = {modified_ast.arg.applied_opts}")
+    raise
+  assert uops[-1].op is Ops.SINK, "last uop must be sink"
+
+  # print and render
+  if DEBUG >= 6: print_uops(uops)
+  src = renderer.render(uops)
+  return ProgramSpec(uops[-1].arg.name, src, renderer.device, ast, uops,
+                     global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)

 # **************** Runners ****************
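`get_program` now wraps the whole AST-to-program path in one entry point. A hypothetical smoke test (the Tensor workload is illustrative; `ScheduleItem.ast` as the way to obtain a SINK-rooted AST is an assumption about the surrounding API):

```python
# Sketch: AST -> ProgramSpec through the new get_program.
from tinygrad import Tensor, Device
from tinygrad.engine.realize import get_program

si = (Tensor.arange(16).reshape(4, 4) * 2).schedule()[-1]    # a ScheduleItem
prog = get_program(si.ast, Device[Device.DEFAULT].renderer)  # optimize, linearize, render
print(prog.name)  # kernel name, taken from the SINK's KernelInfo
print(prog.src)   # rendered program source
```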

tinygrad/opt/__init__.py (new file)

@@ -0,0 +1,29 @@
+# opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
+
+from tinygrad.opt.kernel import Kernel
+from tinygrad.opt.heuristic import hand_coded_optimizations
+from tinygrad.uop.ops import UOp
+from tinygrad.helpers import NOOPT, BEAM, getenv
+from tinygrad.renderer import Renderer
+
+def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
+  """
+  Optimize an AST based on heuristics or BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
+  """
+  k = Kernel(ast, opts=renderer)
+  if not NOOPT:
+    if not k.apply_tensor_cores(getenv("TC", 1)): k.apply_opts(hand_coded_optimizations(k))
+    if BEAM >= 1:
+      from tinygrad.opt.search import beam_search, bufs_from_lin
+      kb = Kernel(ast, opts=renderer)
+      rawbufs = bufs_from_lin(kb, allocate=False)
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+  return k.get_optimized_ast()
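Which path `get_optimized_ast` takes is controlled by the `NOOPT` and `BEAM` context variables it reads. A sketch of toggling them with `tinygrad.helpers.Context` (the matmul workload is illustrative, and BEAM search compiles and times candidate kernels, so it is slow):

```python
# Sketch: the same AST optimized under different context settings.
from tinygrad import Tensor, Device
from tinygrad.helpers import Context
from tinygrad.opt import get_optimized_ast

ast = Tensor.ones(64, 64).matmul(Tensor.ones(64, 64)).schedule()[-1].ast
renderer = Device[Device.DEFAULT].renderer

with Context(NOOPT=1): plain = get_optimized_ast(ast, renderer)  # no opts applied
tuned = get_optimized_ast(ast, renderer)                         # hand-coded heuristics
with Context(BEAM=2): best = get_optimized_ast(ast, renderer)    # beam search, width 2
print(best.arg.applied_opts)  # the KernelInfo in the arg records the applied opts
```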

tinygrad/opt/kernel.py

@@ -457,7 +457,7 @@ class Kernel:
       # otherwise we just replace the VIEW source
       return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
     if op.op is Ops.SINK:
-      return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
+      return ret.replace(arg = KernelInfo(self.name if name_override is None else name_override,
                                           self.local_dims, self.upcasted, self.dont_use_locals, tuple(self.applied_opts)))
     if op.op is Ops.REDUCE_AXIS:
       reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2

tinygrad/renderer/cstyle.py

@@ -135,7 +135,7 @@ class CStyleLanguage(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.SINK:
-        if u.arg is not None: name = u.arg.name
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]

tinygrad/renderer/llvmir.py

@@ -166,7 +166,7 @@ class LLVMRenderer(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.SINK:
-        if u.arg is not None: name = u.arg.name
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"

tinygrad/renderer/ptx.py

@@ -177,7 +177,7 @@ class PTXRenderer(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.SINK:
-        if u.arg is not None: name = u.arg.name
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op is Ops.VECTORIZE:
         r[u] = [cast(str,r[x]) for x in u.src]

tinygrad/uop/ops.py

@@ -6,7 +6,7 @@ from tinygrad.uop import Ops, GroupOp
 from tinygrad.uop.mathtraits import MathTrait
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
-from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put
+from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
   from tinygrad.device import Buffer, MultiBuffer
@@ -519,6 +519,8 @@ class KernelInfo:
   upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
   dont_use_locals: bool = False # don't use local indexing
   applied_opts: tuple = tuple()
+  @property
+  def function_name(self): return to_function_name(self.name)

 # ******** ops in python ********
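With this change, `KernelInfo.name` keeps the human-facing (possibly ANSI-colored) kernel name, while renderers call the new `function_name` property when they need an identifier that is safe in generated source; deriving it lazily keeps DEBUG/VIZ output readable. A small sketch (the colored name is illustrative, and it assumes `KernelInfo`'s fields all have defaults so it can be constructed directly):

```python
# Sketch: display name vs. the sanitized name emitted into generated code.
from tinygrad.uop.ops import KernelInfo

ki = KernelInfo(name="r_\x1b[34m16\x1b[0m_\x1b[90m4\x1b[0m")  # ANSI-colored display name
print(ki.name)           # pretty name, as shown in DEBUG output
print(ki.function_name)  # to_function_name(ki.name): a valid source-level identifier
```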