opt transforms the ast into an optimized ast (#10900)

* opt transforms the ast into an optimized ast

* fix get_kernel order and to_function_name

* function_name property

* update docs

* copy from kernel.py

* improve docs

* ci didn't trigger?
Author: George Hotz
Date: 2025-06-22 09:41:26 -07:00
Committed by: GitHub
parent ffddf165f8
commit b09c47366f
8 changed files with 96 additions and 24 deletions


@@ -1,6 +1,8 @@
 # tinygrad directory layout
-Listed in order of how they are processed
+This explains the flow of a big graph down to programs.
+Directories are listed in order of how they are processed.
 
 ---
@@ -20,7 +22,11 @@ Group UOps into kernels.
 Transforms the ast into an optimized ast. This is where BEAM search and heuristics live.
-When finished, this will just have a function that takes in the ast and returns the optimized ast.
+::: tinygrad.opt.get_optimized_ast
+    options:
+        members: false
+        show_labels: false
+        show_source: false
 
 ---
@@ -38,10 +44,23 @@ Transform the optimized ast into a linearized list of UOps.
 ## tinygrad/renderer
-Transform the linearized list of UOps into a program.
+Transform the linearized list of UOps into a program, represented as a string.
+::: tinygrad.renderer.Renderer
+    options:
+        members:
+        - render
+        show_labels: false
+        show_source: false
 
 ---
 
 ## tinygrad/engine
 Abstracted high level interface to the runtimes.
+::: tinygrad.engine.realize.get_program
+    options:
+        members: false
+        show_labels: false
+        show_source: false
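For a concrete sense of how the documented stages chain together, here is a minimal sketch that runs them by hand. It assumes a kernel AST obtained via Tensor.schedule() and the default device's renderer; it is illustrative, not the canonical entry point (that is get_program, below).

```python
# sketch: run the documented pipeline stages one by one (assumes an Ops.SINK kernel ast)
from tinygrad import Tensor, Device
from tinygrad.opt import get_optimized_ast   # tinygrad/opt: heuristics / BEAM
from tinygrad.codegen import full_rewrite    # tinygrad/codegen: lower to a linear UOp list

ast = (Tensor.rand(16, 16) @ Tensor.rand(16, 16)).schedule()[-1].ast
renderer = Device[Device.DEFAULT].renderer

opt_ast = get_optimized_ast(ast, renderer)   # optimized ast, KernelInfo in .arg
uops = full_rewrite(opt_ast, renderer)       # linearized UOps, ending in Ops.SINK
print(renderer.render(uops))                 # tinygrad/renderer: UOps -> source string
```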


@@ -1,29 +1,51 @@
 from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace, field
-from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
-from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
+from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.opt.kernel import Kernel
-from tinygrad.opt.heuristic import hand_coded_optimizations
 from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.opt import get_optimized_ast
+from tinygrad.codegen import full_rewrite
+from tinygrad.uop.spec import type_verify
 
 # **************** Program Creation ****************
 
-logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
+@track_rewrites(name=lambda _ast,_renderer,ret:ret.name)
 def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
-  k = Kernel(ast, opts=renderer)
-  if not NOOPT:
-    if not k.apply_tensor_cores(getenv("TC", 1)): k.apply_opts(hand_coded_optimizations(k))
-    if BEAM >= 1:
-      from tinygrad.opt.search import beam_search, bufs_from_lin
-      kb = Kernel(ast, opts=renderer)
-      rawbufs = bufs_from_lin(kb, allocate=False)
-      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
-  return k.to_program()
+  """
+  Transform an AST into a ProgramSpec. May trigger BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The ProgramSpec of the program.
+  """
+  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
+  modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None else ast
+  if __debug__: type_verify(list(modified_ast.toposort()))
+
+  # linearize
+  try:
+    uops = full_rewrite(modified_ast, renderer)
+  except RuntimeError:
+    print("***** LINEARIZE FAILURE *****")
+    print(f"ast = {ast}")
+    print(f"opts = {modified_ast.arg.applied_opts}")
+    raise
+  assert uops[-1].op is Ops.SINK, "last uop must be sink"
+
+  # print and render
+  if DEBUG >= 6: print_uops(uops)
+  src = renderer.render(uops)
+  return ProgramSpec(uops[-1].arg.name, src, renderer.device, ast, uops,
+    global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
 
 # **************** Runners ****************
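After this change get_program is a thin pipeline: optimize (skipped when the SINK already carries a KernelInfo arg), verify, linearize, render. A hedged usage sketch, with the same assumption as above about where the ast comes from:

```python
from tinygrad import Tensor, Device
from tinygrad.engine.realize import get_program

ast = (Tensor.rand(16, 16) @ Tensor.rand(16, 16)).schedule()[-1].ast
prg = get_program(ast, Device[Device.DEFAULT].renderer)  # runs BEAM search if BEAM>=1

print(prg.name)  # kernel name taken from the SINK's KernelInfo
print(prg.src)   # rendered program source
```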

tinygrad/opt/__init__.py (new file)

@@ -0,0 +1,29 @@
+# opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
+from tinygrad.opt.kernel import Kernel
+from tinygrad.opt.heuristic import hand_coded_optimizations
+from tinygrad.uop.ops import UOp
+from tinygrad.helpers import NOOPT, BEAM, getenv
+from tinygrad.renderer import Renderer
+
+def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
+  """
+  Optimize an AST based on heuristics or BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
+  """
+  k = Kernel(ast, opts=renderer)
+  if not NOOPT:
+    if not k.apply_tensor_cores(getenv("TC", 1)): k.apply_opts(hand_coded_optimizations(k))
+    if BEAM >= 1:
+      from tinygrad.opt.search import beam_search, bufs_from_lin
+      kb = Kernel(ast, opts=renderer)
+      rawbufs = bufs_from_lin(kb, allocate=False)
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+  return k.get_optimized_ast()
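Since the returned SINK now carries its KernelInfo in arg, the opts that the heuristics (or BEAM) picked are directly inspectable. A small sketch, under the same assumptions as above:

```python
from tinygrad import Tensor, Device
from tinygrad.opt import get_optimized_ast

ast = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).schedule()[-1].ast
opt_ast = get_optimized_ast(ast, Device[Device.DEFAULT].renderer)

ki = opt_ast.arg        # the KernelInfo attached to the SINK
print(ki.name)          # pretty (possibly colored) kernel name
print(ki.applied_opts)  # tuple of the Opt steps that were applied
```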


@@ -457,7 +457,7 @@ class Kernel:
       # otherwise we just replace the VIEW source
       return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
     if op.op is Ops.SINK:
-      return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
+      return ret.replace(arg = KernelInfo(self.name if name_override is None else name_override,
                                           self.local_dims, self.upcasted, self.dont_use_locals, tuple(self.applied_opts)))
     if op.op is Ops.REDUCE_AXIS:
       reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2


@@ -135,7 +135,7 @@ class CStyleLanguage(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.SINK:
-        if u.arg is not None: name = u.arg.name
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]


@@ -166,7 +166,7 @@ class LLVMRenderer(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.SINK:
-        if u.arg is not None: name = u.arg.name
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"


@@ -177,7 +177,7 @@ class PTXRenderer(Renderer):
     name = "test"
     for u in uops:
       if u.op is Ops.SINK:
-        if u.arg is not None: name = u.arg.name
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op is Ops.VECTORIZE:
         r[u] = [cast(str,r[x]) for x in u.src]


@@ -6,7 +6,7 @@ from tinygrad.uop import Ops, GroupOp
 from tinygrad.uop.mathtraits import MathTrait
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
-from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put
+from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
   from tinygrad.device import Buffer, MultiBuffer
@@ -519,6 +519,8 @@ class KernelInfo:
   upcasted: int = 0  # count that are upcasted (this is remapping RANGE to UNROLL)
   dont_use_locals: bool = False  # don't use local indexing
   applied_opts: tuple = tuple()
+  @property
+  def function_name(self): return to_function_name(self.name)
 
 # ******** ops in python ********
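The effect of the split: KernelInfo.name keeps the pretty, possibly ANSI-colored name for DEBUG output, while renderers call the new function_name property to get a sanitized identifier. A quick illustrative sketch (the colored name here is made up):

```python
from tinygrad.uop.ops import KernelInfo

# "\x1b[34m...\x1b[0m" is ANSI blue, as used in tinygrad's pretty kernel names
ki = KernelInfo(name="r_\x1b[34m16\x1b[0m_4")
print(repr(ki.name))     # escape codes preserved for terminal display
print(ki.function_name)  # to_function_name strips ANSI, encodes non-identifier chars -> e.g. "r_16_4"
```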
# ******** ops in python ******** # ******** ops in python ********