cleaning up search with Program (#4500)

* cleaning up search

* fix tests

* test fix

* minor compiler cleanup
This commit is contained in:
George Hotz
2024-05-09 19:01:53 -07:00
committed by GitHub
parent d3dc332c2e
commit 1e843d495e
10 changed files with 52 additions and 57 deletions

View File

@@ -59,8 +59,8 @@ lin = Device[DEVICE].get_linearizer(st_0).linearize()
for u in lin.uops: print(u)
# compile a program (and print the source)
fxn = Device[DEVICE].to_program(lin)
print(fxn.p.prg)
fxn = Device[DEVICE].to_runner(lin)
print(fxn.p.src)
# NOTE: fxn.clprg is the ClangProgram
# run the program

View File

@@ -25,7 +25,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for ji in run.jit_cache:
fxn: Program = ji.prg.p
functions[fxn.function_name] = fxn.prg # NOTE: this assumes all with the same name are the same
functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(ji.bufs):
key = id(arg)

View File

@@ -55,7 +55,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
# TODO: images needs required_optimization
try:
prg = device.to_program(lin)
prg = device.to_runner(lin)
except Exception:
traceback.print_exc()
return "COMPILE_ERROR"

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
# cuda compile
lin = ast_str_to_lin(ast, opts=dev.compiler.compiler_opts)
lin.hand_coded_optimizations()
cuda_prg = dev.to_program(lin)
cuda_prg = dev.to_runner(lin)
bufs = bufs_from_lin(lin)

View File

@@ -29,7 +29,7 @@ class TestFusionOp(unittest.TestCase):
sched = create_schedule([a.lazydata], None)
ji = Device[Device.DEFAULT].get_runner(*sched[-1].ast)
self.assertLess(time.perf_counter()-st, 1.0)
assert len(ji.p.prg.splitlines()) < 250
assert len(ji.p.src.splitlines()) < 250
def test_recursive_add_cmp(self):
st = time.perf_counter()

View File

@@ -269,7 +269,7 @@ class TestLinearizer(unittest.TestCase):
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg = Device[Device.DEFAULT].to_program(k)
prg = Device[Device.DEFAULT].to_runner(k)
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
prg.exec(real_bufs)
result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
@@ -602,19 +602,19 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
# Get baseline, which is not optimized at all.
k = Linearizer(realized_ast)
prg = Device[Device.DEFAULT].to_program(k)
prg = Device[Device.DEFAULT].to_runner(k)
prg.exec(real_bufs)
wanna_output = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np).copy()
# Check correctness of handcoded optimizations.
k = Linearizer(realized_ast)
k.hand_coded_optimizations()
prg = Device[Device.DEFAULT].to_program(k)
prg = Device[Device.DEFAULT].to_runner(k)
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np), atol=atol, rtol=rtol)
for i, x in enumerate(opts): # Check custom transformations if any.
check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program, color_sizes[i] if i < len(color_sizes) else None)
check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_runner, color_sizes[i] if i < len(color_sizes) else None)
class TestKernelOpts(unittest.TestCase):
def test_local_and_grouped_reduce(self):

View File

@@ -41,7 +41,7 @@ class TestBEAM(unittest.TestCase):
with Context(BEAM=0): Tensor.zeros(16).contiguous().realize()
k_beam_0 = capturing[0].captured
capturing.clear()
assert k_beam_0[-1].prg.p.prg != k_beam_1[-1].prg.p.prg
assert k_beam_0[-1].prg.p.src != k_beam_1[-1].prg.p.src
def test_get_linearizer_actions(self):
from test.test_linearizer import helper_realized_ast

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import multiprocessing
from collections import defaultdict
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, cast
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
import importlib, inspect, functools, pathlib, time, ctypes, os
from tinygrad.helpers import prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -148,10 +148,19 @@ class Compiler:
if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
return lib
# Linearize `k` and package the result as a Program: the kernel name, the source rendered
# by this compiler, the target device name, the launch sizes, the uops and the op/memory estimates.
#   k:               the Linearizer to lower; `k.linearize()` is called on it here.
#   override_device: when truthy, used as the Program's device name instead of
#                    `self.compiler_opts.device`.
def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
k.linearize()
# estimate taken from the first AST node
info = get_lazyop_info(k.ast[0])
# counts derived from the uops; they are multiplied by run_count below
ops, mem = k.uops.flops_mem()
# product of every global and local size entry; a missing (falsy) size contributes no entries
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
# NOTE(review): the last three positional arguments are presumably Program's uops, op_estimate and
# mem_estimate fields -- those field declarations are not visible in this hunk; confirm the order.
return Program(k.name, self.render(to_function_name(k.name), k.uops), override_device if override_device else self.compiler_opts.device,
k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
@dataclass(frozen=True)
class Program:
name:str
prg:str
src:str
dname:str
global_size:Optional[List[int]]=None
local_size:Optional[List[int]]=None
@@ -171,10 +180,6 @@ class Program:
@functools.cached_property
def function_name(self) -> str: return to_function_name(self.name)
def compile(self, cached=True) -> bytes:
compiler = cast(Compiler, Device[self.dname].compiler)
return compiler.compile_cached(self.prg) if cached else compiler.compile(self.prg)
def launch_dims(self, var_vals:Dict[Variable, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
@@ -182,9 +187,9 @@ class Program:
class CompiledRunner(Runner):
def __init__(self, p:Program, precompiled:Optional[bytes]=None):
if DEBUG >= 4: print(p.prg)
if DEBUG >= 4: print(p.src)
self.p:Program = p
self.lib:bytes = precompiled if precompiled is not None else self.p.compile()
self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
@@ -210,21 +215,12 @@ method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] =
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
class Compiled:
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
def synchronize(self): pass # override this in your device
def to_program(self, k:Linearizer) -> CompiledRunner:
assert self.compiler is not None, "compiler is required to run AST"
k.linearize()
info = get_lazyop_info(k.ast[0])
ops, mem = k.uops.flops_mem()
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
return CompiledRunner(Program(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count)))
# Build a Program for `k` through this device's compiler (forcing the Program's device name to
# `self.dname`) and wrap it in a CompiledRunner.
def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.to_program(k, override_device=self.dname))
def get_linearizer(self, *ast:LazyOp) -> Linearizer:
assert self.compiler is not None, "compiler is required to build AST"
if DEBUG >= 3:
from tinygrad.features.graph import print_tree
for op in ast: print_tree(op)
@@ -261,6 +257,6 @@ class Compiled:
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=self.dname), bret.lib)
else:
method_cache[ckey] = method_cache[bkey] = ret = self.to_program(self.get_linearizer(*ast))
method_cache[ckey] = method_cache[bkey] = ret = self.to_runner(self.get_linearizer(*ast))
return ret

View File

@@ -1,7 +1,8 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal
from collections import defaultdict
from tinygrad.device import Device, Compiled, Buffer, CompiledRunner, Compiler, Program
from dataclasses import replace
from tinygrad.device import Device, Buffer, CompiledRunner, Compiler, Program
from tinygrad.ops import MemBuffer
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
@@ -31,12 +32,12 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
break
return test_global_size, factor
def _time_program(uops, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs,
early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
factor = 1
if global_size is not None and max_global_size is not None:
global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
try: car = CompiledRunner(Program(name, "", rdev.dname, global_size, local_size, uops), precompiled=lib)
if p.global_size is not None and max_global_size is not None:
global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
p = replace(p, global_size=global_size)
try: car = CompiledRunner(p, precompiled=lib)
except AssertionError: return [math.inf] * cnt
tms = []
for _ in range(cnt):
@@ -46,19 +47,15 @@ def _time_program(uops, rdev:Compiled, lib:bytes, global_size, local_size, var_v
if early_stop is not None and early_stop < tms[-1]: break
return tms
def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None, enforce_max:bool=False) \
-> Tuple[bytes, Optional[List[int]], Optional[List[int]], UOpGraph, float]:
lin.linearize()
if enforce_max and len(lin.uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
if DEBUG >= 5: print(src)
st = time.perf_counter()
prog = compiler.compile(src)
et = time.perf_counter() - st
return prog, lin.global_size, lin.local_size, lin.uops, et
def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler):
try: return x[0], _compile_linearizer(compiler, x[1], "test", enforce_max=True)
# Compile one indexed candidate. `x` is an (index, Linearizer) pair; the index is always
# returned as the first element of the result.
# Returns (index, (Program, compiled bytes, compile time in seconds)) on success,
# or (index, None) if any Exception is raised while linearizing, rendering or compiling.
def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
try:
x[1].linearize()
# chained comparison: the limit only applies when BEAM_UOPS_MAX is > 0 (default 3000).
# The RuntimeError is caught by the handler below, so an oversized kernel yields (index, None).
if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
p = compiler.to_program(x[1])
# only the compiler.compile call is timed; linearize and to_program happen before the timer starts
st = time.perf_counter()
prog = compiler.compile(p.src)
et = time.perf_counter() - st
return x[0], (p, prog, et)
except Exception:
# failures are swallowed; the traceback is printed only when DEBUG >= 4
if DEBUG >= 4: traceback.print_exc()
return x[0], None
@@ -131,14 +128,14 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
lib, global_size, local_size, uops, compile_et = proc
p, lib, compile_et = proc
if lib in seen_libs: continue
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
seen_libs.add(lib)
try: tms = _time_program(uops, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
except RuntimeError: continue # for runtime issues
timed_lins.append((acted_lins[i], min(tms)))
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(uops.uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
# done
@@ -168,7 +165,8 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
return ret[1]
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} # noqa: E501
key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
dev = Device[lin.opts.device]
@@ -176,8 +174,9 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
lib, global_size, local_size, uops, _ = _compile_linearizer(dev.compiler, lin)
tms = _time_program(uops, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
p = dev.compiler.to_program(lin)
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)

View File

@@ -14,7 +14,7 @@ class ClangGraph(GraphRunner):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.prg for ji in jit_cache]))
prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
args += [f"int {v.expr}" for v in var_vals]
code = ["void batched("+','.join(args)+") {"]