mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
debug: add optional detailed BEAM_LOG logging (#3883)
* debug: add optional detailed BEAM_LOG logging show uop count, compile and run times for each candidate in search also add --timing to verify_kernel.py to make it easier to explore hand-crafted applied opts * fix linter
This commit is contained in:
8
test/external/verify_kernel.py
vendored
8
test/external/verify_kernel.py
vendored
@@ -4,6 +4,7 @@ from extra.optimization.helpers import kern_str_to_lin
|
||||
from test.external.fuzz_linearizer import compare_linearizer
|
||||
from tinygrad.helpers import colored
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.features.search import time_linearizer
|
||||
|
||||
# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
|
||||
|
||||
@@ -17,6 +18,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
|
||||
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
|
||||
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
|
||||
parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.kernel is not None:
|
||||
@@ -38,9 +40,13 @@ if __name__ == "__main__":
|
||||
test_lin = kern_str_to_lin(kern_str)
|
||||
for op in test_lin.ast: print_tree(op)
|
||||
print(test_lin.colored_shape())
|
||||
if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
|
||||
(msg,rb,vv,gt) = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)
|
||||
if msg != "PASS":
|
||||
failed_ids.append(i)
|
||||
failures[msg].append((test_lin.ast, test_lin.applied_opts))
|
||||
if args.timing:
|
||||
tm = time_linearizer(test_lin, rb, allow_test_size=False, cnt=10)
|
||||
print(f"final time {tm*1e6:9.0f} us")
|
||||
|
||||
for msg, errors in failures.items():
|
||||
for i, (ast, opts) in enumerate(errors):
|
||||
|
||||
@@ -47,11 +47,14 @@ def _time_program(variables:List[Variable], outcount:int, rdev:Compiled, lib:byt
|
||||
return tms
|
||||
|
||||
def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]],
|
||||
List[Variable], int]:
|
||||
List[Variable], int, float, int]:
|
||||
lin.linearize()
|
||||
src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
|
||||
if DEBUG >= 5: print(src)
|
||||
return compiler.compile(src), lin.global_size, lin.local_size, lin.uops.vars(), len(lin.outbufs)
|
||||
st = time.perf_counter()
|
||||
prog = compiler.compile(src)
|
||||
et = time.perf_counter() - st
|
||||
return prog, lin.global_size, lin.local_size, lin.uops.vars(), len(lin.outbufs), et, len(lin.uops.uops)
|
||||
|
||||
def _try_compile_linearized_w_idx(x, compiler:Compiler):
|
||||
try: return (x[0], _compile_linearizer(compiler, x[1], "test"))
|
||||
@@ -122,14 +125,15 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
||||
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
|
||||
if proc is None: continue
|
||||
lib, global_size, local_size, vars, outcount = proc
|
||||
lib, global_size, local_size, vars, outcount, compile_et, num_uops = proc
|
||||
if lib in seen_libs: continue
|
||||
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(vars, outcount, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
||||
except RuntimeError: continue # for runtime issues
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
if getenv("BEAM_LOG", 0) > 0: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {num_uops:5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
@@ -166,7 +170,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
|
||||
assert isinstance(dev, Compiled) and dev.compiler is not None
|
||||
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
|
||||
lib, global_size, local_size, vars, outcount = _compile_linearizer(dev.compiler, lin)
|
||||
lib, global_size, local_size, vars, outcount, _, _ = _compile_linearizer(dev.compiler, lin)
|
||||
tms = _time_program(vars, outcount, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
|
||||
|
||||
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
|
||||
|
||||
Reference in New Issue
Block a user