Revert "rewrites for renderer and compiler (#13646)" (#13806)

This reverts commit 339dadf056.
George Hotz
2025-12-22 19:21:33 -05:00
committed by GitHub
parent 339dadf056
commit 2af2b4da5d
8 changed files with 30 additions and 66 deletions

View File

@@ -26,14 +26,14 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))]
if ensure_triggered:
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA])
tcs = len([x for x in program.applied_opts if x.op is OptOps.TC])
assert wmmas > 0, "tensor core not triggered"
assert tcs == 1, "tensor core opt not included"
else:
try:
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
assert False, "OptOps.TC triggered, expected KernelOptError"
except KernelOptError: pass
@@ -44,7 +44,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
if dtype_in == dtypes.bfloat16: r = r.float()
realized_ast, bufs = helper_realized_ast(r)
opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts), device=Device.DEFAULT))
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT))
if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg.exec(bufs)

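What these test hunks restore: get_program takes only the AST, the renderer, and optional opts, and the concrete device is bound later when the ProgramSpec is wrapped in a runner. A rough sketch of that call shape (realized_ast stands in for whatever helper_realized_ast returns; import paths are assumed from current tinygrad):

from dataclasses import replace
from tinygrad import Device
from tinygrad.engine.realize import get_program, CompiledRunner  # assumed import path

# realized_ast: placeholder for an Ops.SINK-rooted kernel AST
spec = get_program(realized_ast, Device[Device.DEFAULT].renderer)   # no device argument anymore
prg = CompiledRunner(replace(spec, device=Device.DEFAULT))          # the device is attached here instead
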
View File

@@ -18,7 +18,6 @@ from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.device import Compiler
pm_syntactic_sugar = PatternMatcher([
# INDEX on ptr INDEX concats them
@@ -124,30 +123,6 @@ def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
newlst.extend(ret[1])
return newlst
def do_linearize(prg:UOp, sink:UOp) -> UOp:
lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
if SPEC: type_verify(lst, program_spec)
return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
def do_render(ctx:tuple[Renderer, Compiler], prg:UOp, lin:UOp) -> UOp:
src = ctx[0].render(list(lin.src))
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
def do_compile(ctx:tuple[Renderer, Compiler], prg:UOp, src:UOp) -> UOp:
lib = ctx[1].compile_cached(src.arg)
return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(), UPat(Ops.SOURCE, name="src")), name="prg"), do_compile),
])
def full_rewrite_to_program(sink:UOp, ren:Renderer, compiler:Compiler) -> UOp:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
sink = UOp(Ops.PROGRAM, src=(full_sink,))
return graph_rewrite(sink, pm_to_program, ctx=(ren, compiler), name="linearize/render/compile")
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
"""
Function to transform the Kernel UOp graph into a linearized program.

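This hunk drops the staged PROGRAM rewrite from codegen: do_linearize/do_render/do_compile, the pm_to_program matcher, and full_rewrite_to_program all go away. After the revert the same three stages run as ordinary calls split between get_program and CompiledRunner (see the hunks below); roughly, assuming ast, renderer, and device are in scope:

# sketch of the post-revert pipeline, names taken from the hunks below
uops = full_rewrite(ast, renderer)                    # kernel AST -> linearized list of UOps
src = renderer.render(uops)                           # UOps -> source string (previously Ops.SOURCE)
lib = Device[device].compiler.compile_cached(src)     # source -> binary (previously Ops.BINARY), now done inside CompiledRunner
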
View File

@@ -40,7 +40,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
if allow_test_size and p.global_size is not None and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
try: car = CompiledRunner(replace(p, lib=lib))
try: car = CompiledRunner(p, precompiled=lib)
except AssertionError: return [math.inf] * cnt
tms = []
input_bufs = [rawbufs[i] for i in car.p.globals]

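Because ProgramSpec no longer carries the compiled binary (its lib field is deleted in a later hunk), the timing helper hands the bytes directly to the runner via the restored precompiled argument. A minimal sketch, assuming p is a ProgramSpec, lib its compiled bytes, and input_bufs/var_vals as in the hunk above:

from tinygrad.engine.realize import CompiledRunner  # assumed import path

car = CompiledRunner(p, precompiled=lib)      # reuse the existing binary, skip compile_cached
tm = car(input_bufs, var_vals, wait=True)     # __call__ returns the wall time when wait=True
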
View File

@@ -4,28 +4,26 @@ from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.helpers import unwrap
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.codegen import full_rewrite_to_program
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.opt import Opt
# **************** Program Creation ****************
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
def get_program(ast:UOp, renderer:Renderer, device:str|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
"""
Transform an AST into a ProgramSpec. May trigger BEAM search.
Args:
ast: The Ops.SINK rooted AST
renderer: The renderer used to generate the code
device: The device to compile for (defaults to renderer.device)
Returns:
The ProgramSpec of the program.
"""
device = device or renderer.device
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
if DEBUG >= 5: print(pyrender(ast))
@@ -34,14 +32,20 @@ def get_program(ast:UOp, renderer:Renderer, device:str|None=None, opts:list[Opt]
if opts is not None:
assert ast.arg is None, "can't apply opts if sink has an arg"
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
try:
uops = full_rewrite(ast, renderer)
except RuntimeError as e:
print("***** LINEARIZE FAILURE *****")
print(e)
print(pyrender(ast))
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"
prg = full_rewrite_to_program(ast, renderer, Device[device].compiler)
# SINK/LINEAR/SOURCE/BINARY
sink, linear, source, binary = prg.src
# print and render
if DEBUG >= 6: print_uops(uops)
src = renderer.render(uops)
# legacy
return ProgramSpec(sink.arg.name, source.arg, device, sink, list(linear.src), binary.arg,
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
local_size=[1,1,1] if renderer.has_local else None)
@@ -72,18 +76,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
return ret[1]
class CompiledRunner(Runner):
def __init__(self, p:ProgramSpec, prg=None):
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
if DEBUG >= 3: print(p.applied_opts)
if DEBUG >= 4: print(p.src)
if p.lib is None:
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src))
self.p:ProgramSpec = p
if DEBUG >= 7: Device[p.device].compiler.disassemble(unwrap(p.lib))
self._prg = Device[p.device].runtime(p.function_name, p.lib) if prg is None else prg
if precompiled is not None: self.lib = precompiled
else:
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
self.lib = Device[p.device].compiler.compile_cached(p.src)
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
super().__init__(p.name, p.device, p.estimates)
def __reduce__(self): return self.__class__, (self.p,)
def __reduce__(self): return self.__class__, (self.p, self.lib)
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
if var_vals is None: var_vals = {}
@@ -157,9 +162,9 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device))
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
else:
prg: ProgramSpec = get_program(ast, Device[device].renderer, device)
prg: ProgramSpec = get_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
return ret

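The CompiledRunner changes here move compilation back into the runner: __init__ either takes a precompiled binary or calls compile_cached itself, the binary lives on self.lib rather than on the ProgramSpec, and __reduce__ pickles (p, lib) so a serialized runner round-trips without recompiling. A rough sketch of what that means for a caller (ast is a placeholder AST; import path assumed):

import pickle
from tinygrad import Device
from tinygrad.engine.realize import get_program, CompiledRunner  # assumed import path

spec = get_program(ast, Device[Device.DEFAULT].renderer)   # ast: placeholder Ops.SINK kernel AST
runner = CompiledRunner(spec)                              # compiles spec.src here, under cpu_profile
restored = pickle.loads(pickle.dumps(runner))              # __reduce__ supplies (spec, runner.lib), so no recompile
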
View File

@@ -64,7 +64,6 @@ class ProgramSpec:
device:str
ast:UOp # save the base ast (this is method cache key)
uops:list[UOp]|None=None
lib:bytes|None=None # compiled binary
# filled in from uops (if we have uops)
global_size:list[int]|None=None

View File

@@ -27,10 +27,6 @@ class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); REWRITE_ERROR = auto()
# renderer/compiler
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has a bytes arg that's not
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
# GROUP is a NOOP that just merges things together
SINK = auto(); AFTER = auto(); GROUP = auto()

View File

@@ -218,8 +218,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
return None
case Ops.INDEX:

View File

@@ -249,16 +249,6 @@ full_spec = PatternMatcher([
# in progress MSTACK may lose device
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
# codegen: PROGRAM with progressive sources through the pipeline
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
# codegen: standalone LINEAR/SOURCE/BINARY
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
(UPat(Ops.VECTORIZE), lambda: True),
(UPat(Ops.INDEX), lambda: True),