mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
search: add a BEAM_COMPARE env to optionally not compare to hc/tc (#4107)
* search: add a BEAM_COMPARE env to optionally not compare to hc/tc. Setting BEAM_COMPARE=0 will prevent the additional memory allocation needed to do the timing tests, assuming the BEAM result is in the diskcache.
* change to optionally use Buffer.allocate
This commit is contained in:
@@ -23,8 +23,8 @@ if __name__ == '__main__':
|
||||
with open(args.file, 'r') as file:
|
||||
ast_strs = file.readlines()
|
||||
|
||||
for ast_str in ast_strs:
|
||||
print(f"optimizing ast={ast_str}")
|
||||
for i, ast_str in enumerate(ast_strs):
|
||||
print(f"optimizing {i}/{len(ast_strs)}\nast={ast_str}")
|
||||
lin = ast_str_to_lin(ast_str, opts=device.compiler.compiler_opts)
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
|
||||
@@ -233,21 +233,24 @@ class Compiled:
|
||||
if not NOOPT:
|
||||
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
|
||||
if BEAM >= 1:
|
||||
lins = [(("tc" if used_tensor_cores else "hc"), k)]
|
||||
if used_tensor_cores:
|
||||
lins.append(("hc", Linearizer(*ast, opts=self.compiler.compiler_opts)))
|
||||
lins[-1][1].hand_coded_optimizations()
|
||||
kb = Linearizer(*ast, opts=self.compiler.compiler_opts)
|
||||
kb.required_optimizations()
|
||||
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
|
||||
test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
|
||||
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
k = timed[0][1]
|
||||
if logkern is not None and logkern_level > 1: logkern.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
|
||||
kb, k_opt = Linearizer(*ast, opts=self.compiler.compiler_opts), k
|
||||
kb.required_optimizations()
|
||||
rawbufs = bufs_from_lin(kb, allocate=False)
|
||||
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
if getenv("BEAM_COMPARE", 1):
|
||||
# TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
|
||||
lins = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
|
||||
if used_tensor_cores:
|
||||
lins.append(("hc", Linearizer(*ast, opts=self.compiler.compiler_opts)))
|
||||
lins[-1][1].hand_coded_optimizations()
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
k = timed[0][1]
|
||||
if logkern is not None and logkern_level > 1: logkern.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
|
||||
# TODO: check the correctness inline once compare_linearizer is in core
|
||||
if logkern is not None: logkern.writelines([f"{(k.ast, k.applied_opts)}\n"])
|
||||
if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
|
||||
return k
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
|
||||
@@ -64,17 +64,19 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler):
|
||||
# workers should ignore ctrl c
|
||||
def _init_worker():
  """Set this process's SIGINT disposition to SIG_IGN so ctrl-c is ignored here."""
  signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]:
  """Return a new list mirroring `bufs`, substituting `buf.allocate()` for any entry without a `_buf` attribute.

  Entries that already have `_buf` are passed through unchanged, in order.
  """
  ensured = []
  for buf in bufs:
    # the presence of `_buf` is the only signal used to decide whether allocate() is still needed
    if hasattr(buf, "_buf"):
      ensured.append(buf)
    else:
      ensured.append(buf.allocate())
  return ensured
|
||||
|
||||
# *** external API ***
|
||||
|
||||
# get (scrap) buffers for timing the linearizer
|
||||
def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
|
||||
def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
|
||||
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
|
||||
for x in lin.membufs: bufsts[x.idx].append(x)
|
||||
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
buf_size = prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.real_size() for y in lx)
|
||||
if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
|
||||
rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype).allocate()
|
||||
rawbufs[k] = Buffer(lin.opts.device, buf_size, lx[0].dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, lx[0].dtype)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
return cast(List[Buffer], rawbufs)
|
||||
|
||||
@@ -97,7 +99,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
|
||||
return acted_lins
|
||||
|
||||
beam_pool = None
|
||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||
def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
|
||||
global beam_pool
|
||||
key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||
@@ -113,6 +115,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
||||
beam_pool = multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
|
||||
|
||||
try:
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[lin.opts.device]
|
||||
@@ -137,14 +140,13 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
||||
exiting = len(opts) == 0 or (len(beam) > 0 and ((beam[0][1]-opts[0][1])*1e6 < min_progress_micros))
|
||||
if not exiting: beam = opts[:amt]
|
||||
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
|
||||
assert len(beam) > 0, "no BEAM items succeeded?!?"
|
||||
assert len(beam) > 0, "no BEAM items succeeded?!?" # this asserts in unet3d multi-gpu, need to figure out why
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
|
||||
except KeyboardInterrupt as e:
|
||||
if beam_pool is not None: beam_pool.terminate()
|
||||
raise e
|
||||
|
||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||
if DEBUG >= 3: print(beam[0][0].applied_opts)
|
||||
return beam[0][0]
|
||||
|
||||
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
|
||||
@@ -166,6 +168,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
|
||||
dev = Device[lin.opts.device]
|
||||
assert dev.compiler is not None
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
|
||||
lib, global_size, local_size, vars, outcount, _, _ = _compile_linearizer(dev.compiler, lin)
|
||||
tms = _time_program(vars, outcount, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user