mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
cleaning up search with Program (#4500)
* cleaning up search * fix tests * test fix * minor compiler cleanup
This commit is contained in:
@@ -59,8 +59,8 @@ lin = Device[DEVICE].get_linearizer(st_0).linearize()
|
||||
for u in lin.uops: print(u)
|
||||
|
||||
# compile a program (and print the source)
|
||||
fxn = Device[DEVICE].to_program(lin)
|
||||
print(fxn.p.prg)
|
||||
fxn = Device[DEVICE].to_runner(lin)
|
||||
print(fxn.p.src)
|
||||
# NOTE: fxn.clprg is the ClangProgram
|
||||
|
||||
# run the program
|
||||
|
||||
@@ -25,7 +25,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for ji in run.jit_cache:
|
||||
fxn: Program = ji.prg.p
|
||||
functions[fxn.function_name] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i,arg in enumerate(ji.bufs):
|
||||
key = id(arg)
|
||||
|
||||
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -55,7 +55,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
|
||||
|
||||
# TODO: images needs required_optimization
|
||||
try:
|
||||
prg = device.to_program(lin)
|
||||
prg = device.to_runner(lin)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return "COMPILE_ERROR"
|
||||
|
||||
2
test/external/speed_compare_cuda_ptx.py
vendored
2
test/external/speed_compare_cuda_ptx.py
vendored
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
# cuda compile
|
||||
lin = ast_str_to_lin(ast, opts=dev.compiler.compiler_opts)
|
||||
lin.hand_coded_optimizations()
|
||||
cuda_prg = dev.to_program(lin)
|
||||
cuda_prg = dev.to_runner(lin)
|
||||
|
||||
bufs = bufs_from_lin(lin)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestFusionOp(unittest.TestCase):
|
||||
sched = create_schedule([a.lazydata], None)
|
||||
ji = Device[Device.DEFAULT].get_runner(*sched[-1].ast)
|
||||
self.assertLess(time.perf_counter()-st, 1.0)
|
||||
assert len(ji.p.prg.splitlines()) < 250
|
||||
assert len(ji.p.src.splitlines()) < 250
|
||||
|
||||
def test_recursive_add_cmp(self):
|
||||
st = time.perf_counter()
|
||||
|
||||
@@ -269,7 +269,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
|
||||
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
prg = Device[Device.DEFAULT].to_runner(k)
|
||||
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
|
||||
prg.exec(real_bufs)
|
||||
result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
|
||||
@@ -602,19 +602,19 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
|
||||
|
||||
# Get baseline, which is not optimized at all.
|
||||
k = Linearizer(realized_ast)
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
prg = Device[Device.DEFAULT].to_runner(k)
|
||||
prg.exec(real_bufs)
|
||||
wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()
|
||||
|
||||
# Check correctness of handcoded optimiztions.
|
||||
k = Linearizer(realized_ast)
|
||||
k.hand_coded_optimizations()
|
||||
prg = Device[Device.DEFAULT].to_program(k)
|
||||
prg = Device[Device.DEFAULT].to_runner(k)
|
||||
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
|
||||
prg.exec(real_bufs)
|
||||
np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=atol, rtol=rtol)
|
||||
for i, x in enumerate(opts): # Check custom transformations if any.
|
||||
check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program, color_sizes[i] if i < len(color_sizes) else None)
|
||||
check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_runner, color_sizes[i] if i < len(color_sizes) else None)
|
||||
|
||||
class TestKernelOpts(unittest.TestCase):
|
||||
def test_local_and_grouped_reduce(self):
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestBEAM(unittest.TestCase):
|
||||
with Context(BEAM=0): Tensor.zeros(16).contiguous().realize()
|
||||
k_beam_0 = capturing[0].captured
|
||||
capturing.clear()
|
||||
assert k_beam_0[-1].prg.p.prg != k_beam_1[-1].prg.p.prg
|
||||
assert k_beam_0[-1].prg.p.src != k_beam_1[-1].prg.p.src
|
||||
|
||||
def test_get_linearizer_actions(self):
|
||||
from test.test_linearizer import helper_realized_ast
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import multiprocessing
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, cast
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
|
||||
import importlib, inspect, functools, pathlib, time, ctypes, os
|
||||
from tinygrad.helpers import prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer, sint
|
||||
@@ -148,10 +148,19 @@ class Compiler:
|
||||
if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
|
||||
return lib
|
||||
|
||||
def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
|
||||
k.linearize()
|
||||
info = get_lazyop_info(k.ast[0])
|
||||
ops, mem = k.uops.flops_mem()
|
||||
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
|
||||
# NOTE: we use min here to ignore the indexing FLOPS
|
||||
return Program(k.name, self.render(to_function_name(k.name), k.uops), override_device if override_device else self.compiler_opts.device,
|
||||
k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Program:
|
||||
name:str
|
||||
prg:str
|
||||
src:str
|
||||
dname:str
|
||||
global_size:Optional[List[int]]=None
|
||||
local_size:Optional[List[int]]=None
|
||||
@@ -171,10 +180,6 @@ class Program:
|
||||
@functools.cached_property
|
||||
def function_name(self) -> str: return to_function_name(self.name)
|
||||
|
||||
def compile(self, cached=True) -> bytes:
|
||||
compiler = cast(Compiler, Device[self.dname].compiler)
|
||||
return compiler.compile_cached(self.prg) if cached else compiler.compile(self.prg)
|
||||
|
||||
def launch_dims(self, var_vals:Dict[Variable, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
||||
@@ -182,9 +187,9 @@ class Program:
|
||||
|
||||
class CompiledRunner(Runner):
|
||||
def __init__(self, p:Program, precompiled:Optional[bytes]=None):
|
||||
if DEBUG >= 4: print(p.prg)
|
||||
if DEBUG >= 4: print(p.src)
|
||||
self.p:Program = p
|
||||
self.lib:bytes = precompiled if precompiled is not None else self.p.compile()
|
||||
self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
|
||||
self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
|
||||
super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
|
||||
|
||||
@@ -210,21 +215,12 @@ method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] =
|
||||
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
|
||||
class Compiled:
|
||||
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
|
||||
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
|
||||
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
|
||||
def synchronize(self): pass # override this in your device
|
||||
|
||||
def to_program(self, k:Linearizer) -> CompiledRunner:
|
||||
assert self.compiler is not None, "compiler is required to run AST"
|
||||
k.linearize()
|
||||
info = get_lazyop_info(k.ast[0])
|
||||
ops, mem = k.uops.flops_mem()
|
||||
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
|
||||
# NOTE: we use min here to ignore the indexing FLOPS
|
||||
return CompiledRunner(Program(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
|
||||
k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count)))
|
||||
def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.to_program(k, override_device=self.dname))
|
||||
|
||||
def get_linearizer(self, *ast:LazyOp) -> Linearizer:
|
||||
assert self.compiler is not None, "compiler is required to build AST"
|
||||
if DEBUG >= 3:
|
||||
from tinygrad.features.graph import print_tree
|
||||
for op in ast: print_tree(op)
|
||||
@@ -261,6 +257,6 @@ class Compiled:
|
||||
if bret:=method_cache.get(bkey):
|
||||
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=self.dname), bret.lib)
|
||||
else:
|
||||
method_cache[ckey] = method_cache[bkey] = ret = self.to_program(self.get_linearizer(*ast))
|
||||
method_cache[ckey] = method_cache[bkey] = ret = self.to_runner(self.get_linearizer(*ast))
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, functools, random, math, time, multiprocessing, traceback, signal
|
||||
from collections import defaultdict
|
||||
from tinygrad.device import Device, Compiled, Buffer, CompiledRunner, Compiler, Program
|
||||
from dataclasses import replace
|
||||
from tinygrad.device import Device, Buffer, CompiledRunner, Compiler, Program
|
||||
from tinygrad.ops import MemBuffer
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
||||
from tinygrad.dtype import ImageDType
|
||||
@@ -31,12 +32,12 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
|
||||
break
|
||||
return test_global_size, factor
|
||||
|
||||
def _time_program(uops, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs,
|
||||
early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
|
||||
def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
|
||||
factor = 1
|
||||
if global_size is not None and max_global_size is not None:
|
||||
global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
|
||||
try: car = CompiledRunner(Program(name, "", rdev.dname, global_size, local_size, uops), precompiled=lib)
|
||||
if p.global_size is not None and max_global_size is not None:
|
||||
global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
p = replace(p, global_size=global_size)
|
||||
try: car = CompiledRunner(p, precompiled=lib)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
for _ in range(cnt):
|
||||
@@ -46,19 +47,15 @@ def _time_program(uops, rdev:Compiled, lib:bytes, global_size, local_size, var_v
|
||||
if early_stop is not None and early_stop < tms[-1]: break
|
||||
return tms
|
||||
|
||||
def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None, enforce_max:bool=False) \
|
||||
-> Tuple[bytes, Optional[List[int]], Optional[List[int]], UOpGraph, float]:
|
||||
lin.linearize()
|
||||
if enforce_max and len(lin.uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
|
||||
src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
|
||||
if DEBUG >= 5: print(src)
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(src)
|
||||
et = time.perf_counter() - st
|
||||
return prog, lin.global_size, lin.local_size, lin.uops, et
|
||||
|
||||
def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler):
|
||||
try: return x[0], _compile_linearizer(compiler, x[1], "test", enforce_max=True)
|
||||
def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
|
||||
try:
|
||||
x[1].linearize()
|
||||
if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
|
||||
p = compiler.to_program(x[1])
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(p.src)
|
||||
et = time.perf_counter() - st
|
||||
return x[0], (p, prog, et)
|
||||
except Exception:
|
||||
if DEBUG >= 4: traceback.print_exc()
|
||||
return x[0], None
|
||||
@@ -131,14 +128,14 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
||||
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
|
||||
if proc is None: continue
|
||||
lib, global_size, local_size, uops, compile_et = proc
|
||||
p, lib, compile_et = proc
|
||||
if lib in seen_libs: continue
|
||||
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(uops, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
||||
except RuntimeError: continue # for runtime issues
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(uops.uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
|
||||
# done
|
||||
@@ -168,7 +165,8 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
|
||||
return ret[1]
|
||||
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
|
||||
key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} # noqa: E501
|
||||
key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
|
||||
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
||||
|
||||
dev = Device[lin.opts.device]
|
||||
@@ -176,8 +174,9 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
|
||||
lib, global_size, local_size, uops, _ = _compile_linearizer(dev.compiler, lin)
|
||||
tms = _time_program(uops, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
|
||||
p = dev.compiler.to_program(lin)
|
||||
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
|
||||
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
|
||||
|
||||
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
|
||||
return min(tms)
|
||||
|
||||
@@ -14,7 +14,7 @@ class ClangGraph(GraphRunner):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
|
||||
|
||||
prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.prg for ji in jit_cache]))
|
||||
prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
|
||||
args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
|
||||
args += [f"int {v.expr}" for v in var_vals]
|
||||
code = ["void batched("+','.join(args)+") {"]
|
||||
|
||||
Reference in New Issue
Block a user