get_program refactor (#13816)
* get_program refactor
* fix docs
* cleanup
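In short: `get_program` (AST to ProgramSpec, possibly triggering BEAM search) moves from `tinygrad.engine.realize` into `tinygrad.codegen`, absorbing the old `full_rewrite_to_program`, and `ProgramSpec.from_uop` now extracts all program metadata instead of `__post_init__`. A minimal sketch of the new call path; it assumes this tinygrad version's `Tensor.schedule` and `Device.default.renderer` APIs, and the variable names are illustrative:

```python
from tinygrad import Tensor, Device
from tinygrad.uop.ops import Ops
from tinygrad.codegen import get_program  # new home; previously tinygrad.engine.realize

# schedule a small elementwise kernel and take its SINK-rooted AST
a, b = Tensor.randn(16), Tensor.randn(16)
asts = [si.ast for si in (a + b).schedule() if si.ast.op is Ops.SINK]

# AST -> ProgramSpec: optimize (heuristics or BEAM), linearize, render
prg = get_program(asts[-1], Device.default.renderer)
print(prg.name)  # kernel name
print(prg.src)   # rendered source code
```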
```diff
@@ -28,7 +28,7 @@ Transforms the ast into an optimized ast. This is where BEAM search and heuristi
 Transform the optimized ast into a linearized and rendered program.
 
-::: tinygrad.codegen.full_rewrite_to_program
+::: tinygrad.codegen.get_program
     options:
       members: false
      show_labels: false
```
```diff
@@ -1,11 +1,12 @@
 from typing import cast
 import itertools
-from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG
-from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, print_uops
+from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey
+from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
 from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
-from tinygrad.renderer import Renderer
+from tinygrad.renderer import Renderer, ProgramSpec
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.helpers import panic
+from tinygrad.codegen.opt import Opt
 
 # import all pattern matchers here
 from tinygrad.codegen.gpudims import pm_add_gpudims
@@ -28,6 +29,8 @@ pm_syntactic_sugar = PatternMatcher([
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()
+  if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
+  if DEBUG >= 5: print(pyrender(sink))
   if SPEC: type_verify(sink, kernel_spec)
 
   # preprocess
@@ -137,11 +140,29 @@ pm_to_program = PatternMatcher([
   (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
 ])
 
-def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
-  from tinygrad.uop.ops import KernelInfo
-  if sink.arg is None: sink = sink.replace(arg=KernelInfo())
-  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  sink = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=ren.device)))
-  prg = graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
-  if DEBUG >= 6: print_uops(list(prg.src[2].src)) # LINEAR is src[2]
-  return prg
+@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
+  """
+  Transform an AST into a ProgramSpec. May trigger BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The ProgramSpec of the program.
+  """
+
+  # fix up KernelInfo
+  if opts is not None:
+    assert ast.arg is None, "can't apply opts if sink has an arg"
+    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
+  if ast.arg is None: ast = ast.replace(arg=KernelInfo())
+
+  # rewrite to prg
+  full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
+  prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
+  prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
+
+  # create the ProgramSpec
+  return ProgramSpec.from_uop(prg)
```
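The new `opts` parameter lets a caller pin explicit kernel optimizations, which `get_program` wraps into `KernelInfo(opts_to_apply=...)` on the sink. A hedged sketch continuing the snippet above; it assumes `Opt(OptOps.UPCAST, 0, 4)` is a legal optimization for this kernel shape:

```python
from tinygrad import Device
from tinygrad.codegen import get_program
from tinygrad.codegen.opt import Opt, OptOps

# explicit opts require the sink's arg to still be None (see the assert above);
# get_program then attaches KernelInfo(opts_to_apply=(Opt(...),)) itself
prg_opt = get_program(asts[-1], Device.default.renderer, opts=[Opt(OptOps.UPCAST, 0, 4)])
print(prg_opt.src)
```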
```diff
@@ -6,7 +6,8 @@ from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, di
 from tinygrad.helpers import IGNORE_BEAM_CACHE
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.engine.realize import CompiledRunner, get_program
+from tinygrad.engine.realize import CompiledRunner
+from tinygrad.codegen import get_program
 from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.opt.postrange import Scheduler
```
```diff
@@ -2,38 +2,12 @@ from typing import cast, Callable
 import time, pprint, random, itertools, math
 from dataclasses import dataclass, replace, field
 from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
-from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
 from tinygrad.helpers import unwrap
-from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
 from tinygrad.device import Device, Buffer
-from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.codegen import full_rewrite_to_program
-from tinygrad.codegen.opt import Opt
-
-# **************** Program Creation ****************
-
-@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
-  """
-  Transform an AST into a ProgramSpec. May trigger BEAM search.
-
-  Args:
-    ast: The Ops.SINK rooted AST
-    renderer: The renderer used to generate the code
-
-  Returns:
-    The ProgramSpec of the program.
-  """
-
-  # linearize
-  if opts is not None:
-    assert ast.arg is None, "can't apply opts if sink has an arg"
-    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
-
-  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
-  if DEBUG >= 5: print(pyrender(ast))
-
-  return ProgramSpec.from_uop(full_rewrite_to_program(ast, renderer))
+from tinygrad.renderer import ProgramSpec, Estimates
+from tinygrad.codegen import get_program
 
 # **************** Runners ****************
```
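With program creation moved out, `engine/realize.py` keeps only the runner side. A short sketch of the remaining flow, assuming `CompiledRunner` still takes a `ProgramSpec` and compiles it for its device:

```python
from tinygrad.engine.realize import CompiledRunner

# compile the ProgramSpec produced by tinygrad.codegen.get_program (see above);
# actually executing it additionally needs allocated Buffers and var_vals
runner = CompiledRunner(prg)
```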
```diff
@@ -2,8 +2,8 @@ from __future__ import annotations
 from typing import Callable, cast
 import functools
 from dataclasses import dataclass, field
-from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
+from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
+from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops
 from tinygrad.dtype import AddrSpace, PtrDType
 from tinygrad.codegen.opt.tc import TensorCore
 from tinygrad.codegen.opt import Opt
@@ -65,35 +65,13 @@ class ProgramSpec:
   ast:UOp # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
 
-  # filled in from uops (if we have uops)
+  # filled in from uops (via from_uop)
   global_size:list[int]=field(default_factory=lambda: [1,1,1])
   local_size:list[int]|None=None
   vars:list[Variable]=field(default_factory=list)
   globals:list[int]=field(default_factory=list)
   outs:list[int]=field(default_factory=list)
   ins:list[int]=field(default_factory=list)
-  _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program
-
-  def __post_init__(self):
-    if not self._ran_post_init and self.uops is not None:
-      # single pass through the uops
-      for u in self.uops:
-        if u.op is Ops.DEFINE_VAR: self.vars.append(u)
-        if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
-        if u.op in (Ops.STORE, Ops.LOAD):
-          if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
-            if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: (self.outs if u.op is Ops.STORE else self.ins).append(buf.arg)
-          # TODO: can else happen?
-        if u.op is Ops.SPECIAL:
-          # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
-          if u.arg[0] == 'i': self.local_size = None
-          special_size = self.local_size if u.arg[0] == 'l' else self.global_size
-          # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
-          if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
-      self.vars = sorted(self.vars, key=lambda v: v.arg)
-      self.outs = sorted(dedup(self.outs))
-      self.ins = sorted(dedup(self.ins))
-      self._ran_post_init = True
 
   @functools.cached_property
   def estimates(self) -> Estimates:
@@ -119,7 +97,31 @@ class ProgramSpec:
     assert prg.op is Ops.PROGRAM, f"expected PROGRAM, got {prg.op}"
     # SINK/DEVICE/LINEAR/SOURCE
     sink, device, linear, source = prg.src
-    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, list(linear.src), global_size=[1,1,1], local_size=[1,1,1])
+    uops = list(linear.src)
+    if DEBUG >= 6: print_uops(uops) # LINEAR is src[2]
+
+    # single pass through the uops to extract metadata
+    _vars: list[Variable] = []
+    _globals: list[int] = []
+    outs: list[int] = []
+    ins: list[int] = []
+    global_size: list[int] = [1, 1, 1]
+    local_size: list[int]|None = [1, 1, 1]
+    for u in uops:
+      if u.op is Ops.DEFINE_VAR: _vars.append(u)
+      if u.op is Ops.DEFINE_GLOBAL: _globals.append(u.arg)
+      if u.op in (Ops.STORE, Ops.LOAD):
+        if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
+          if (buf:=idx.src[0]).op is Ops.DEFINE_GLOBAL: (outs if u.op is Ops.STORE else ins).append(buf.arg)
+        # TODO: can else happen?
+      if u.op is Ops.SPECIAL:
+        if u.arg[0] == 'i': local_size = None
+        special_size = local_size if u.arg[0] == 'l' else global_size
+        # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
+        if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
+
+    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, global_size, local_size,
+                       sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))
 
 class Renderer:
   device: str = ""
```
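After this change a `ProgramSpec` arrives fully populated from `from_uop` rather than mutating itself in `__post_init__`. The fields filled by the single pass over the uops can be inspected directly; continuing the sketch above:

```python
# metadata extracted by ProgramSpec.from_uop in one pass over the uops
print(prg.globals)                      # DEFINE_GLOBAL buffer args, sorted and deduped
print(prg.outs, prg.ins)                # buffers behind STOREs vs LOADs
print(prg.global_size, prg.local_size)  # from SPECIAL uops; local_size is None for 'i' indexing
print(prg.vars)                         # DEFINE_VARs sorted by arg
```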