get_program refactor (#13816)

* get_program refactor

* fix docs

* cleanup
George Hotz
2025-12-23 16:44:46 -05:00
committed by GitHub
parent 3d3c5b2fb9
commit 8eab6175ee
5 changed files with 67 additions and 69 deletions

View File

@@ -28,7 +28,7 @@ Transforms the ast into an optimized ast. This is where BEAM search and heuristi
 Transform the optimized ast into a linearized and rendered program.
-::: tinygrad.codegen.full_rewrite_to_program
+::: tinygrad.codegen.get_program
     options:
         members: false
         show_labels: false

View File

@@ -1,11 +1,12 @@
 from typing import cast
 import itertools
-from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG
-from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, print_uops
+from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey
+from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
 from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
-from tinygrad.renderer import Renderer
+from tinygrad.renderer import Renderer, ProgramSpec
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import panic
+from tinygrad.codegen.opt import Opt
 # import all pattern matchers here
 from tinygrad.codegen.gpudims import pm_add_gpudims
@@ -28,6 +29,8 @@ pm_syntactic_sugar = PatternMatcher([
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()
+  if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
+  if DEBUG >= 5: print(pyrender(sink))
   if SPEC: type_verify(sink, kernel_spec)
   # preprocess
@@ -137,11 +140,29 @@ pm_to_program = PatternMatcher([
   (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
 ])
 
-def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
-  from tinygrad.uop.ops import KernelInfo
-  if sink.arg is None: sink = sink.replace(arg=KernelInfo())
-  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device)))
-  prg = graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
-  if DEBUG >= 6: print_uops(list(prg.src[2].src))  # LINEAR is src[2]
-  return prg
+@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
+  """
+  Transform an AST into a ProgramSpec. May trigger BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The ProgramSpec of the program.
+  """
+  # fix up KernelInfo
+  if opts is not None:
+    assert ast.arg is None, "can't apply opts if sink has an arg"
+    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
+  if ast.arg is None: ast = ast.replace(arg=KernelInfo())
+  # rewrite to prg
+  full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
+  prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
+  prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
+  # create the ProgramSpec
+  return ProgramSpec.from_uop(prg)
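
With this, get_program is the public codegen entry point. A minimal sketch of calling it, assuming Tensor.schedule() still returns ScheduleItem objects with an .ast field (the UPCAST opt is an arbitrary illustrative choice and may raise if it is invalid for the kernel):

from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.codegen import get_program
from tinygrad.codegen.opt import Opt, OptOps

si = (Tensor.ones(16) + 1).schedule()[-1]  # a real Ops.SINK rooted AST from the scheduler
renderer = Device[Device.DEFAULT].renderer
prg = get_program(si.ast, renderer)        # KernelInfo fixed up, may trigger BEAM
print(prg.name)
print(prg.src)                             # the rendered kernel source
# explicit opts are only legal while the sink has no arg yet:
prg2 = get_program(si.ast, renderer, opts=[Opt(OptOps.UPCAST, 0, 4)])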

View File

@@ -6,7 +6,8 @@ from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, di
 from tinygrad.helpers import IGNORE_BEAM_CACHE
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.engine.realize import CompiledRunner, get_program
+from tinygrad.engine.realize import CompiledRunner
+from tinygrad.codegen import get_program
 from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.opt.postrange import Scheduler
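
For the search code this is purely an import move, but it shows the intended call pattern during BEAM: compile many candidate opt lists through the same get_program. A hypothetical helper, not the actual search code (try_opts is an invented name; assuming illegal opts surface as KernelOptError, which this file already imports):

from tinygrad.codegen import get_program
from tinygrad.codegen.opt import Opt, KernelOptError

def try_opts(ast, renderer, cand:list[Opt]):
  # compile one candidate; None signals the opts were illegal for this ast
  try: return get_program(ast, renderer, opts=cand)
  except KernelOptError: return None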

View File

@@ -2,38 +2,12 @@ from typing import cast, Callable
 import time, pprint, random, itertools, math
 from dataclasses import dataclass, replace, field
 from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
-from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
 from tinygrad.helpers import unwrap
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
 from tinygrad.device import Device, Buffer
-from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.codegen import full_rewrite_to_program
-from tinygrad.codegen.opt import Opt
-
-# **************** Program Creation ****************
-
-@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
-  """
-  Transform an AST into a ProgramSpec. May trigger BEAM search.
-
-  Args:
-    ast: The Ops.SINK rooted AST
-    renderer: The renderer used to generate the code
-
-  Returns:
-    The ProgramSpec of the program.
-  """
-  # linearize
-  if opts is not None:
-    assert ast.arg is None, "can't apply opts if sink has an arg"
-    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
-  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
-  if DEBUG >= 5: print(pyrender(ast))
-  return ProgramSpec.from_uop(full_rewrite_to_program(ast, renderer))
+from tinygrad.renderer import ProgramSpec, Estimates
+from tinygrad.codegen import get_program
 
 # **************** Runners ****************
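
Any downstream code that imported get_program from its old location needs a one-line change (illustrative):

# before this commit: from tinygrad.engine.realize import get_program
from tinygrad.codegen import get_program  # its new home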

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
 from typing import Callable, cast
 import functools
 from dataclasses import dataclass, field
-from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
+from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
+from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops
 from tinygrad.dtype import AddrSpace, PtrDType
 from tinygrad.codegen.opt.tc import TensorCore
 from tinygrad.codegen.opt import Opt
@@ -65,35 +65,13 @@ class ProgramSpec:
   ast:UOp  # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
-  # filled in from uops (if we have uops)
+  # filled in from uops (via from_uop)
   global_size:list[int]=field(default_factory=lambda: [1,1,1])
   local_size:list[int]|None=None
   vars:list[Variable]=field(default_factory=list)
   globals:list[int]=field(default_factory=list)
   outs:list[int]=field(default_factory=list)
   ins:list[int]=field(default_factory=list)
-  _ran_post_init:bool=False  # NOTE: this is needed if you call replace on the Program
-
-  def __post_init__(self):
-    if not self._ran_post_init and self.uops is not None:
-      # single pass through the uops
-      for u in self.uops:
-        if u.op is Ops.DEFINE_VAR: self.vars.append(u)
-        if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
-        if u.op in (Ops.STORE, Ops.LOAD):
-          if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
-            if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: (self.outs if u.op is Ops.STORE else self.ins).append(buf.arg)
-          # TODO: can else happen?
-        if u.op is Ops.SPECIAL:
-          # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
-          if u.arg[0] == 'i': self.local_size = None
-          special_size = self.local_size if u.arg[0] == 'l' else self.global_size
-          # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
-          if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
-      self.vars = sorted(self.vars, key=lambda v: v.arg)
-      self.outs = sorted(dedup(self.outs))
-      self.ins = sorted(dedup(self.ins))
-      self._ran_post_init = True
 
   @functools.cached_property
   def estimates(self) -> Estimates:
@@ -119,7 +97,31 @@
     assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
     # SINK/DEVICE/LINEAR/SOURCE
     sink, device, linear, source = prg.src
-    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, list(linear.src), global_size=[1,1,1], local_size=[1,1,1])
+    uops = list(linear.src)
+    if DEBUG >= 6: print_uops(uops)  # LINEAR is src[2]
+    # single pass through the uops to extract metadata
+    _vars: list[Variable] = []
+    _globals: list[int] = []
+    outs: list[int] = []
+    ins: list[int] = []
+    global_size: list[int] = [1, 1, 1]
+    local_size: list[int]|None = [1, 1, 1]
+    for u in uops:
+      if u.op is Ops.DEFINE_VAR: _vars.append(u)
+      if u.op is Ops.DEFINE_GLOBAL: _globals.append(u.arg)
+      if u.op in (Ops.STORE, Ops.LOAD):
+        if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
+          if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: (outs if u.op is Ops.STORE else ins).append(buf.arg)
+        # TODO: can else happen?
+      if u.op is Ops.SPECIAL:
+        if u.arg[0] == 'i': local_size = None
+        special_size = local_size if u.arg[0] == 'l' else global_size
+        # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
+        if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
+    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, global_size, local_size,
+                       sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))
 
 class Renderer:
   device: str = ""
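
The SPECIAL handling in from_uop keys off the arg string: SPECIAL args are names like "gidx0"/"lidx1", so arg[0] picks global ('g'), local ('l'), or flat index ('i'), and arg[-1] is the axis digit. A standalone sketch of that size-extraction logic, assuming this naming scheme:

# mirror of the loop above, on plain strings instead of uops
global_size, local_size = [1, 1, 1], [1, 1, 1]
for name, size in [("gidx0", 16), ("gidx1", 8), ("lidx0", 32)]:
  if name[0] == 'i': local_size = None  # a flat index kernel has no local dims
  target = local_size if name[0] == 'l' else global_size
  if target is not None: target[int(name[-1])] = size
assert global_size == [16, 8, 1] and local_size == [32, 1, 1]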