From edce2303f425b028c75ed64f4f2d6af8c753b4ae Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 22 Dec 2025 20:03:33 -0500 Subject: [PATCH] rewrite to program (#13808) --- tinygrad/codegen/__init__.py | 19 +++++++++++++++++++ tinygrad/engine/realize.py | 22 +++++++++------------- tinygrad/uop/__init__.py | 4 ++++ tinygrad/uop/ops.py | 3 ++- tinygrad/uop/spec.py | 8 ++++++++ 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 572df13857..b359ee0c6e 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -123,6 +123,25 @@ def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]: newlst.extend(ret[1]) return newlst +def do_linearize(prg:UOp, sink:UOp) -> UOp: + lst = line_rewrite(linearize(sink), pm_linearize_cleanups) + if SPEC: type_verify(lst, program_spec) + return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),)) + +def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp: + src = ctx.render(list(lin.src)) + return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),)) + +pm_to_program = PatternMatcher([ + (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize), + (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), +]) + +def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp: + full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None) + sink = UOp(Ops.PROGRAM, src=(full_sink,)) + return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render") + def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]: """ Function to transform the Kernel UOp graph into a linearized program. diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index d2451e1c8b..52898f88a8 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -7,7 +7,7 @@ from tinygrad.helpers import unwrap from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates -from tinygrad.codegen import full_rewrite +from tinygrad.codegen import full_rewrite_to_program from tinygrad.codegen.opt import Opt # **************** Program Creation **************** @@ -32,20 +32,16 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program if opts is not None: assert ast.arg is None, "can't apply opts if sink has an arg" ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - try: - uops = full_rewrite(ast, renderer) - except RuntimeError as e: - print("***** LINEARIZE FAILURE *****") - print(e) - print(pyrender(ast)) - raise - assert uops[-1].op is Ops.SINK, "last uop must be sink" + if ast.arg is None: ast = ast.replace(arg=KernelInfo()) - # print and render - if DEBUG >= 6: print_uops(uops) - src = renderer.render(uops) + prg = full_rewrite_to_program(ast, renderer) + # SINK/LINEAR/SOURCE + sink, linear, source = prg.src - return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops, + # print + if DEBUG >= 6: print_uops(list(linear.src)) + + return ProgramSpec(sink.arg.name, source.arg, renderer.device, sink, list(linear.src), global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None, local_size=[1,1,1] if renderer.has_local else None) diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 39a9427a87..898c014986 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -27,6 +27,10 @@ class Ops(FastEnum): # uops that aren't rendered NOOP = auto(); REWRITE_ERROR = auto() + # renderer + # LINEAR is a list of UOps, SOURCE has a str arg that's human readable + PROGRAM = auto(); LINEAR = auto(); SOURCE = auto() + # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:] # GROUP is a NOOP that just merges things together SINK = auto(); AFTER = auto(); GROUP = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 5830bde033..f016166e74 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -218,7 +218,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL: + Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \ + Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE: return None case Ops.INDEX: diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index b09eb6408b..0c8eee5cbc 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -249,6 +249,14 @@ full_spec = PatternMatcher([ # in progress MSTACK may lose device (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), + # codegen: PROGRAM with progressive sources through the pipeline + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True), + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True), + (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True), + # codegen: standalone LINEAR/SOURCE + (UPat(Ops.LINEAR, dtypes.void), lambda: True), + (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True), + # temp VECTORIZE/INDEX during rewrite have the wrong dtype (UPat(Ops.VECTORIZE), lambda: True), (UPat(Ops.INDEX), lambda: True),