rewrite to program (#13808)

This commit is contained in:
George Hotz
2025-12-22 20:03:33 -05:00
committed by GitHub
parent 2af2b4da5d
commit edce2303f4
5 changed files with 42 additions and 14 deletions

View File

@@ -123,6 +123,25 @@ def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
newlst.extend(ret[1])
return newlst
def do_linearize(prg:UOp, sink:UOp) -> UOp:
lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
if SPEC: type_verify(lst, program_spec)
return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
src = ctx.render(list(lin.src))
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
])
def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
sink = UOp(Ops.PROGRAM, src=(full_sink,))
return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
"""
Function to transform the Kernel UOp graph into a linearized program.

View File

@@ -7,7 +7,7 @@ from tinygrad.helpers import unwrap
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.codegen import full_rewrite
from tinygrad.codegen import full_rewrite_to_program
from tinygrad.codegen.opt import Opt
# **************** Program Creation ****************
@@ -32,20 +32,16 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
if opts is not None:
assert ast.arg is None, "can't apply opts if sink has an arg"
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
try:
uops = full_rewrite(ast, renderer)
except RuntimeError as e:
print("***** LINEARIZE FAILURE *****")
print(e)
print(pyrender(ast))
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
# print and render
if DEBUG >= 6: print_uops(uops)
src = renderer.render(uops)
prg = full_rewrite_to_program(ast, renderer)
# SINK/LINEAR/SOURCE
sink, linear, source = prg.src
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
# print
if DEBUG >= 6: print_uops(list(linear.src))
return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src),
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)

View File

@@ -27,6 +27,10 @@ class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); REWRITE_ERROR = auto()
# renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
# GROUP is a NOOP that just merges things together
SINK = auto(); AFTER = auto(); GROUP = auto()

View File

@@ -218,7 +218,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE:
return None
case Ops.INDEX:

View File

@@ -249,6 +249,14 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
# codegen: standalone LINEAR/SOURCE
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),