From 2af2b4da5df00b1d3cee7845ca7380cbdb7ada3c Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 22 Dec 2025 19:21:33 -0500
Subject: [PATCH] Revert "rewrites for renderer and compiler (#13646)" (#13806)

This reverts commit 339dadf056cd22e9352bd680207baffd9bf92473.
---
 test/opt/test_tensor_cores.py  |  6 ++---
 tinygrad/codegen/__init__.py   | 25 -------------------
 tinygrad/codegen/opt/search.py |  2 +-
 tinygrad/engine/realize.py     | 45 +++++++++++++++++++---------------
 tinygrad/renderer/__init__.py  |  1 -
 tinygrad/uop/__init__.py       |  4 ---
 tinygrad/uop/ops.py            |  3 +--
 tinygrad/uop/spec.py           | 10 --------
 8 files changed, 30 insertions(+), 66 deletions(-)

diff --git a/test/opt/test_tensor_cores.py b/test/opt/test_tensor_cores.py
index 2a3343668d..0b6bad6e9f 100644
--- a/test/opt/test_tensor_cores.py
+++ b/test/opt/test_tensor_cores.py
@@ -26,14 +26,14 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
   opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))]
 
   if ensure_triggered:
-    program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
+    program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
     wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA])
     tcs = len([x for x in program.applied_opts if x.op is OptOps.TC])
     assert wmmas > 0, "tensor core not triggered"
     assert tcs == 1, "tensor core opt not included"
   else:
     try:
-      program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
+      program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
       assert False, "OptOps.TC triggered, expected KernelOptError"
     except KernelOptError: pass
 
@@ -44,7 +44,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
   if dtype_in == dtypes.bfloat16: r = r.float()
   realized_ast, bufs = helper_realized_ast(r)
   opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
-  prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts), device=Device.DEFAULT))
+  prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT))
   if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
   assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
   prg.exec(bufs)
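
[review note, not part of the patch] The two hunks above restore the two-argument form of get_program: the target device is implied by the renderer again instead of being passed alongside it. A minimal sketch of the restored call, assuming a default backend whose renderer can lower this kernel (untested; the example tensor is arbitrary):

    from tinygrad import Tensor
    from tinygrad.device import Device
    from tinygrad.engine.realize import get_program

    # take a SINK-rooted AST from the scheduler, then lower it;
    # after this revert there is no separate device argument
    ast = (Tensor.ones(4, 4).contiguous() + 1).schedule()[-1].ast
    prg = get_program(ast, Device[Device.DEFAULT].renderer)
    print(prg.name, prg.device)  # device now comes from renderer.device
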
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index 0066a7e678..572df13857 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -18,7 +18,6 @@ from tinygrad.codegen.opt.postrange import apply_opts
 from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
-from tinygrad.device import Compiler
 
 pm_syntactic_sugar = PatternMatcher([
   # INDEX on ptr INDEX concats them
@@ -124,30 +123,6 @@ def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
     newlst.extend(ret[1])
   return newlst
 
-def do_linearize(prg:UOp, sink:UOp) -> UOp:
-  lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
-  if SPEC: type_verify(lst, program_spec)
-  return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
-
-def do_render(ctx:tuple[Renderer, Compiler], prg:UOp, lin:UOp) -> UOp:
-  src = ctx[0].render(list(lin.src))
-  return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
-
-def do_compile(ctx:tuple[Renderer, Compiler], prg:UOp, src:UOp) -> UOp:
-  lib = ctx[1].compile_cached(src.arg)
-  return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
-
-pm_to_program = PatternMatcher([
-  (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
-  (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
-  (UPat(Ops.PROGRAM, src=(UPat(), UPat(), UPat(Ops.SOURCE, name="src")), name="prg"), do_compile),
-])
-
-def full_rewrite_to_program(sink:UOp, ren:Renderer, compiler:Compiler) -> UOp:
-  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  sink = UOp(Ops.PROGRAM, src=(full_sink,))
-  return graph_rewrite(sink, pm_to_program, ctx=(ren, compiler), name="linearize/render/compile")
-
 def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
   """
   Function to transform the Kernel UOp graph into a linearized program.
diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py
index 546f344264..084352c486 100644
--- a/tinygrad/codegen/opt/search.py
+++ b/tinygrad/codegen/opt/search.py
@@ -40,7 +40,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
   if allow_test_size and p.global_size is not None and max_global_size is not None:
     global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
     p = replace(p, global_size=global_size)
-  try: car = CompiledRunner(replace(p, lib=lib))
+  try: car = CompiledRunner(p, precompiled=lib)
   except AssertionError: return [math.inf] * cnt
   tms = []
   input_bufs = [rawbufs[i] for i in car.p.globals]
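
[review note, not part of the patch] For context on the deletion above: the removed helpers staged compilation as a single graph_rewrite over a PROGRAM uop, appending one src per stage; _time_program in search.py now hands the beam-searched binary to CompiledRunner as precompiled instead of storing it on the spec. Stage order, as encoded by the removed matchers:

    # stages of the deleted pm_to_program pipeline (from the removed UPats):
    #   PROGRAM(SINK)                          -> do_linearize appends LINEAR
    #   PROGRAM(SINK, LINEAR)                  -> do_render appends SOURCE
    #   PROGRAM(SINK, LINEAR, SOURCE)          -> do_compile appends BINARY
    #   PROGRAM(SINK, LINEAR, SOURCE, BINARY)  -> unpacked by get_program
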
""" - device = device or renderer.device if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if DEBUG >= 5: print(pyrender(ast)) @@ -34,14 +32,20 @@ def get_program(ast:UOp, renderer:Renderer, device:str|None=None, opts:list[Opt] if opts is not None: assert ast.arg is None, "can't apply opts if sink has an arg" ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - if ast.arg is None: ast = ast.replace(arg=KernelInfo()) + try: + uops = full_rewrite(ast, renderer) + except RuntimeError as e: + print("***** LINEARIZE FAILURE *****") + print(e) + print(pyrender(ast)) + raise + assert uops[-1].op is Ops.SINK, "last uop must be sink" - prg = full_rewrite_to_program(ast, renderer, Device[device].compiler) - # SINK/LINEAR/SOURCE/BINARY - sink, linear, source, binary = prg.src + # print and render + if DEBUG >= 6: print_uops(uops) + src = renderer.render(uops) - # legacy - return ProgramSpec(sink.arg.name, source.arg, device, sink, list(linear.src), binary.arg, + return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops, global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None, local_size=[1,1,1] if renderer.has_local else None) @@ -72,18 +76,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe return ret[1] class CompiledRunner(Runner): - def __init__(self, p:ProgramSpec, prg=None): + def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None): if DEBUG >= 3: print(p.applied_opts) if DEBUG >= 4: print(p.src) - if p.lib is None: - with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"): - p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src)) self.p:ProgramSpec = p - if DEBUG >= 7: Device[p.device].compiler.disassemble(unwrap(p.lib)) - self._prg = Device[p.device].runtime(p.function_name, p.lib) if prg is None else prg + if precompiled is not None: self.lib = precompiled + else: + with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"): + self.lib = Device[p.device].compiler.compile_cached(p.src) + if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib) + self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg super().__init__(p.name, p.device, p.estimates) - def __reduce__(self): return self.__class__, (self.p,) + def __reduce__(self): return self.__class__, (self.p, self.lib) def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None: if var_vals is None: var_vals = {} @@ -157,9 +162,9 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner: if cret:=method_cache.get(ckey): return cret bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True) if bret:=method_cache.get(bkey): - method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device)) + method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib) else: - prg: ProgramSpec = get_program(ast, Device[device].renderer, device) + prg: ProgramSpec = get_program(ast, Device[device].renderer) method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device)) return ret diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 0e3bd4f33c..c63dbff3df 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -64,7 +64,6 @@ class ProgramSpec: device:str ast:UOp # save the base ast (this is method cache key) uops:list[UOp]|None=None - 
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index 0e3bd4f33c..c63dbff3df 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -64,7 +64,6 @@ class ProgramSpec:
   device:str
   ast:UOp # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
-  lib:bytes|None=None # compiled binary
 
   # filled in from uops (if we have uops)
   global_size:list[int]|None=None
diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py
index 2fa1ddbabe..39a9427a87 100644
--- a/tinygrad/uop/__init__.py
+++ b/tinygrad/uop/__init__.py
@@ -27,10 +27,6 @@ class Ops(FastEnum):
   # uops that aren't rendered
   NOOP = auto(); REWRITE_ERROR = auto()
 
-  # renderer/compiler
-  # LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has a bytes arg that's not
-  PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
-
   # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
   # GROUP is a NOOP that just merges things together
   SINK = auto(); AFTER = auto(); GROUP = auto()
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index ce4d0c2d3e..5830bde033 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -218,8 +218,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     match self.op:
       # late ops don't have shape
       case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
-        Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
-        Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
+        Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
         return None
 
       case Ops.INDEX:
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index 0d1e43fc42..b09eb6408b 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -249,16 +249,6 @@ full_spec = PatternMatcher([
   # in progress MSTACK may lose device
   (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
 
-  # codegen: PROGRAM with progressive sources through the pipeline
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
-  (UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
-  # codegen: standalone LINEAR/SOURCE/BINARY
-  (UPat(Ops.LINEAR, dtypes.void), lambda: True),
-  (UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
-  (UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
-
   # temp VECTORIZE/INDEX during rewrite have the wrong dtype
   (UPat(Ops.VECTORIZE), lambda: True),
   (UPat(Ops.INDEX), lambda: True),
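
[review note, not part of the patch] A quick post-apply check that the uop surface is back to its pre-#13646 shape; a sketch that relies only on names visible in the hunks above:

    from tinygrad.uop import Ops
    from tinygrad.renderer import ProgramSpec

    # the four staged-compilation ops are gone from the enum...
    assert all(not hasattr(Ops, n) for n in ("PROGRAM", "LINEAR", "SOURCE", "BINARY"))
    # ...and ProgramSpec no longer carries the compiled binary
    assert "lib" not in ProgramSpec.__dataclass_fields__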