mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
This reverts commit 339dadf056.
This commit is contained in:
@@ -26,14 +26,14 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
|
||||
opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))]
|
||||
|
||||
if ensure_triggered:
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
|
||||
wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA])
|
||||
tcs = len([x for x in program.applied_opts if x.op is OptOps.TC])
|
||||
assert wmmas > 0, "tensor core not triggered"
|
||||
assert tcs == 1, "tensor core opt not included"
|
||||
else:
|
||||
try:
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
|
||||
assert False, "OptOps.TC triggered, expected KernelOptError"
|
||||
except KernelOptError: pass
|
||||
|
||||
@@ -44,7 +44,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
|
||||
if dtype_in == dtypes.bfloat16: r = r.float()
|
||||
realized_ast, bufs = helper_realized_ast(r)
|
||||
opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
|
||||
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts), device=Device.DEFAULT))
|
||||
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT))
|
||||
if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
|
||||
assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
prg.exec(bufs)
|
||||
|
||||
@@ -18,7 +18,6 @@ from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
pm_syntactic_sugar = PatternMatcher([
|
||||
# INDEX on ptr INDEX concats them
|
||||
@@ -124,30 +123,6 @@ def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
|
||||
newlst.extend(ret[1])
|
||||
return newlst
|
||||
|
||||
def do_linearize(prg:UOp, sink:UOp) -> UOp:
|
||||
lst = line_rewrite(linearize(sink), pm_linearize_cleanups)
|
||||
if SPEC: type_verify(lst, program_spec)
|
||||
return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),))
|
||||
|
||||
def do_render(ctx:tuple[Renderer, Compiler], prg:UOp, lin:UOp) -> UOp:
|
||||
src = ctx[0].render(list(lin.src))
|
||||
return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
|
||||
|
||||
def do_compile(ctx:tuple[Renderer, Compiler], prg:UOp, src:UOp) -> UOp:
|
||||
lib = ctx[1].compile_cached(src.arg)
|
||||
return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
|
||||
|
||||
pm_to_program = PatternMatcher([
|
||||
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"),), name="prg"), do_linearize),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(), UPat(Ops.SOURCE, name="src")), name="prg"), do_compile),
|
||||
])
|
||||
|
||||
def full_rewrite_to_program(sink:UOp, ren:Renderer, compiler:Compiler) -> UOp:
|
||||
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
|
||||
sink = UOp(Ops.PROGRAM, src=(full_sink,))
|
||||
return graph_rewrite(sink, pm_to_program, ctx=(ren, compiler), name="linearize/render/compile")
|
||||
|
||||
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
|
||||
"""
|
||||
Function to transform the Kernel UOp graph into a linearized program.
|
||||
|
||||
@@ -40,7 +40,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
|
||||
if allow_test_size and p.global_size is not None and max_global_size is not None:
|
||||
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
p = replace(p, global_size=global_size)
|
||||
try: car = CompiledRunner(replace(p, lib=lib))
|
||||
try: car = CompiledRunner(p, precompiled=lib)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
input_bufs = [rawbufs[i] for i in car.p.globals]
|
||||
|
||||
@@ -4,28 +4,26 @@ from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.codegen import full_rewrite_to_program
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
|
||||
def get_program(ast:UOp, renderer:Renderer, device:str|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
|
||||
def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
|
||||
"""
|
||||
Transform an AST into a ProgramSpec. May trigger BEAM search.
|
||||
|
||||
Args:
|
||||
ast: The Ops.SINK rooted AST
|
||||
renderer: The renderer used to generate the code
|
||||
device: The device to compile for (defaults to renderer.device)
|
||||
|
||||
Returns:
|
||||
The ProgramSpec of the program.
|
||||
"""
|
||||
device = device or renderer.device
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
if DEBUG >= 5: print(pyrender(ast))
|
||||
@@ -34,14 +32,20 @@ def get_program(ast:UOp, renderer:Renderer, device:str|None=None, opts:list[Opt]
|
||||
if opts is not None:
|
||||
assert ast.arg is None, "can't apply opts if sink has an arg"
|
||||
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
|
||||
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
|
||||
try:
|
||||
uops = full_rewrite(ast, renderer)
|
||||
except RuntimeError as e:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(e)
|
||||
print(pyrender(ast))
|
||||
raise
|
||||
assert uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
|
||||
prg = full_rewrite_to_program(ast, renderer, Device[device].compiler)
|
||||
# SINK/LINEAR/SOURCE/BINARY
|
||||
sink, linear, source, binary = prg.src
|
||||
# print and render
|
||||
if DEBUG >= 6: print_uops(uops)
|
||||
src = renderer.render(uops)
|
||||
|
||||
# legacy
|
||||
return ProgramSpec(sink.arg.name, source.arg, device, sink, list(linear.src), binary.arg,
|
||||
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
|
||||
global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
|
||||
local_size=[1,1,1] if renderer.has_local else None)
|
||||
|
||||
@@ -72,18 +76,19 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
|
||||
return ret[1]
|
||||
|
||||
class CompiledRunner(Runner):
|
||||
def __init__(self, p:ProgramSpec, prg=None):
|
||||
def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
|
||||
if DEBUG >= 3: print(p.applied_opts)
|
||||
if DEBUG >= 4: print(p.src)
|
||||
if p.lib is None:
|
||||
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
|
||||
p = replace(p, lib=Device[p.device].compiler.compile_cached(p.src))
|
||||
self.p:ProgramSpec = p
|
||||
if DEBUG >= 7: Device[p.device].compiler.disassemble(unwrap(p.lib))
|
||||
self._prg = Device[p.device].runtime(p.function_name, p.lib) if prg is None else prg
|
||||
if precompiled is not None: self.lib = precompiled
|
||||
else:
|
||||
with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
|
||||
self.lib = Device[p.device].compiler.compile_cached(p.src)
|
||||
if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
|
||||
self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
|
||||
super().__init__(p.name, p.device, p.estimates)
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.p,)
|
||||
def __reduce__(self): return self.__class__, (self.p, self.lib)
|
||||
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
|
||||
if var_vals is None: var_vals = {}
|
||||
@@ -157,9 +162,9 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
if cret:=method_cache.get(ckey): return cret
|
||||
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
|
||||
if bret:=method_cache.get(bkey):
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device))
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
|
||||
else:
|
||||
prg: ProgramSpec = get_program(ast, Device[device].renderer, device)
|
||||
prg: ProgramSpec = get_program(ast, Device[device].renderer)
|
||||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
|
||||
return ret
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ class ProgramSpec:
|
||||
device:str
|
||||
ast:UOp # save the base ast (this is method cache key)
|
||||
uops:list[UOp]|None=None
|
||||
lib:bytes|None=None # compiled binary
|
||||
|
||||
# filled in from uops (if we have uops)
|
||||
global_size:list[int]|None=None
|
||||
|
||||
@@ -27,10 +27,6 @@ class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
|
||||
# renderer/compiler
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has a bytes arg that's not
|
||||
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
|
||||
|
||||
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
|
||||
# GROUP is a NOOP that just merges things together
|
||||
SINK = auto(); AFTER = auto(); GROUP = auto()
|
||||
|
||||
@@ -218,8 +218,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
|
||||
@@ -249,16 +249,6 @@ full_spec = PatternMatcher([
|
||||
# in progress MSTACK may lose device
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True),
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK),)), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE))), lambda: True),
|
||||
(UPat(Ops.PROGRAM, dtypes.void, src=(UPat(Ops.SINK), UPat(Ops.LINEAR), UPat(Ops.SOURCE), UPat(Ops.BINARY))), lambda: True),
|
||||
# codegen: standalone LINEAR/SOURCE/BINARY
|
||||
(UPat(Ops.LINEAR, dtypes.void), lambda: True),
|
||||
(UPat(Ops.SOURCE, dtypes.void, src=()), lambda: True),
|
||||
(UPat(Ops.BINARY, dtypes.void, src=()), lambda: True),
|
||||
|
||||
# temp VECTORIZE/INDEX during rewrite have the wrong dtype
|
||||
(UPat(Ops.VECTORIZE), lambda: True),
|
||||
(UPat(Ops.INDEX), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user