diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index cefbba3eb6..fd34c6149a 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -1,5 +1,6 @@
 # stuff needed to unpack a kernel
 from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
+from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
@@ -10,6 +11,12 @@ inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.linearizer import Linearizer
 def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
 def ast_str_to_lin(ast_str:str, opts=None): return Linearizer(ast_str_to_ast(ast_str), opts=opts)
+def kern_str_to_lin(kern_str:str, opts=None):
+  (ast, applied_opts,) = eval(kern_str)
+  k = Linearizer(*ast, opts=opts)
+  for opt in applied_opts:
+    k.apply_opt(opt)
+  return k

 # load worlds, a dataset of about 12k kernels
 import gzip
diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 45a507c43a..920fa0fa64 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -1,15 +1,17 @@
 import random, traceback, ctypes
-from typing import List, Tuple
+from typing import List, Tuple, DefaultDict
 import numpy as np
 from collections import defaultdict
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
-from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.codegen.linearizer import Linearizer, UOp
+from tinygrad.codegen.kernel import Opt
 from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
 from tinygrad.tensor import Tensor
 from tinygrad.features.graph import print_tree
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context
 from tinygrad.device import Device, Compiled
-from tinygrad.codegen.linearizer import UOp
+from tinygrad.lazy import LazyBuffer
+from tinygrad.ops import LazyOp

 def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])

@@ -24,7 +26,7 @@ def get_fuzz_rawbufs(lin):
   with Context(DEBUG=0):
     for rawbuf in rawbufs[1:]:
       t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
-      rawbuf.copyin(t.realize().lazydata.realized.as_buffer())
+      if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
   return rawbufs

 def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
@@ -38,32 +40,43 @@ def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):

 def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
   if rawbufs is None: rawbufs = bufs_from_lin(lin)
-  if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()}
+  if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}

   # TODO: images needs required_optimization
   try:
-    if isinstance(device, Compiled):
-      prg = device.to_program(lin)
-    else:
-      prg = device.get_runner(lin.ast)
+    prg = device.to_program(lin)
   except Exception:
-    print(lin.ast)
-    print(lin.applied_opts)
     traceback.print_exc()
-    print("COMPILE FAILED!!")
     return "COMPILE_ERROR"

   try:
-    prg.exec(rawbufs, var_vals)
+    prg(rawbufs, var_vals, wait=True, do_update_stats=False)
   except Exception:
-    print(lin.ast)
-    print(lin.applied_opts)
     traceback.print_exc()
-    print("EXEC FAILED!!")
     return "EXEC_ERROR"

   return "PASS"

+def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
+  try:
+    if rawbufs is None:
+      rawbufs = get_fuzz_rawbufs(lin)
+    else:
+      rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
+  except BaseException:
+    return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
+  if var_vals is None: var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
+  if ground_truth is None:
+    unoptimized = Linearizer(*lin.ast)
+    unoptimized.required_optimizations()
+    if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
+      return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
+    ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
+
+  if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
+    return (run_msg, rawbufs, var_vals, ground_truth,)
+  result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
+  return ("PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR", rawbufs, var_vals, ground_truth,)

 def fuzz_linearizer(lin: Linearizer):
   SEED = getenv("SEED", 42)
@@ -73,7 +86,8 @@ def fuzz_linearizer(lin: Linearizer):
   print(lin.colored_shape())
   seen_uops = {}
   last_lins = [lin]
-  failures = defaultdict(list)
+  failures:DefaultDict[str, List[Tuple[Tuple[LazyOp,...],List[Opt]]]] = defaultdict(list)
+  rawbufs, var_vals, ground_truth = None, None, None

   FUZZ_BEAM = getenv("FUZZ_BEAM", 0)
   FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
@@ -81,23 +95,6 @@ def fuzz_linearizer(lin: Linearizer):
     print("skipping large kernel")
     return failures

-  # get baseline unoptimized output
-  unoptimized = Linearizer(*lin.ast)
-  var_vals = {v: random.randint(v.min, v.max) for v in lin.ast[0].vars()}
-
-  try:
-    rawbufs = get_fuzz_rawbufs(lin)
-  except Exception:
-    traceback.print_exc()
-    print("RAWBUFS FAILED!!")
-    failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
-    return failures
-
-  if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
-    failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
-    return failures
-  ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
-
   for depth in range(getenv("DEPTH", 1 if FUZZ_BEAM else 10)):
     next_lins = []
     for lin in last_lins:
@@ -118,23 +115,15 @@ def fuzz_linearizer(lin: Linearizer):
         seen_uops[tuops] = tuple(test_lin.applied_opts)

         if not FUZZ_BEAM: print(test_lin.colored_shape())
-        # get a new output buffer
-        rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
-        if (msg := run_linearizer(test_lin, rawbufs, var_vals)) != "PASS":
+
+        (msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth)
+        if msg != "PASS":
+          print(test_lin.ast)
+          print(test_lin.applied_opts)
+          print(msg)
           failures[msg].append((test_lin.ast, test_lin.applied_opts))
           continue

-        result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
-        try:
-          # compare memoryviews directly
-          np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2)
-        except AssertionError:
-          print(test_lin.ast)
-          print(test_lin.applied_opts)
-          traceback.print_exc()
-          print("COMPARE FAILED!!")
-          failures["COMPARE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
-          continue
         next_lins.append(test_lin)

     last_lins = next_lins
diff --git a/test/external/verify_kernel.py b/test/external/verify_kernel.py
new file mode 100644
index 0000000000..006ad7b77d
--- /dev/null
+++ b/test/external/verify_kernel.py
@@ -0,0 +1,57 @@
+import argparse
+from collections import defaultdict
+from extra.optimization.helpers import kern_str_to_lin
+from test.external.fuzz_linearizer import compare_linearizer
+from tinygrad.helpers import colored
+from tinygrad.features.graph import print_tree
+
+# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
+
+# Example for GPT2:
+# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERN=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
+# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
+  parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
+  parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
+  parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
+  parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
+  args = parser.parse_args()
+
+  if args.kernel is not None:
+    print("loading kernel from args")
+    kern_strs = [args.kernel]
+  elif args.file is not None:
+    print(f"loading kernel from file '{args.file}'")
+    with open(args.file, 'r') as file:
+      kern_strs = file.readlines()
+  else:
+    raise RuntimeError("no kernel specified; use --kernel or --file options")
+
+  print(f"verifying {len(kern_strs)} kernels")
+
+  failed_ids = []
+  failures = defaultdict(list)
+  for i, kern_str in enumerate(kern_strs):
+    print(f"testing kernel {i}")
+    test_lin = kern_str_to_lin(kern_str)
+    for op in test_lin.ast: print_tree(op)
+    print(test_lin.colored_shape())
+    if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
+      failed_ids.append(i)
+      failures[msg].append((test_lin.ast, test_lin.applied_opts))
+
+  for msg, errors in failures.items():
+    for i, (ast, opts) in enumerate(errors):
+      print(f"{msg} {i} AST: {ast}")
+      print(f"{msg} {i} OPTS: {opts}\n")
+
+  print(f"tested {len(kern_strs)} kernels")
+  if failures:
+    print(f"{failed_ids=}")
+    for msg, errors in failures.items():
+      print(f"{msg}: {len(errors)}")
+    raise RuntimeError(f"failed on {len(failed_ids)} kernels")
+  else:
+    print(colored("all passed", "green"))
\ No newline at end of file
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 896e3bf6c7..496ec85f64 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -239,6 +239,7 @@ class MultiDeviceJITGraph(JITRunner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")

+logkern, logkern_level = open(getenv("LOGKERN", ""), "a") if getenv("LOGKERN", "") else None, getenv("LOGKERN_LEVEL", 1)
 class Compiled:
   def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
@@ -278,6 +279,9 @@ class Compiled:
         timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
         k = timed[0][1]
+        if logkern is not None and logkern_level > 1: logkern.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
+    # TODO: check the correctness inline once compare_linearizer is in core
+    if logkern is not None: logkern.writelines([f"{(k.ast, k.applied_opts)}\n"])
     return k

   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
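---

The new compare_linearizer entry point can also be driven outside the fuzz loop. A minimal sketch, not part of the diff: it assumes PYTHONPATH=. so that extra/ and test/ are importable, and that load_worlds() is acceptable with its default filters.

# sketch: validate a few dataset kernels against their unoptimized ground truth
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from test.external.fuzz_linearizer import compare_linearizer

for ast_str in load_worlds()[:10]:
  lin = ast_str_to_lin(ast_str)  # rebuild a Linearizer from a logged AST string
  # rawbufs, var_vals, and ground_truth are generated internally when None
  msg, rawbufs, var_vals, ground_truth = compare_linearizer(lin)
  print(msg, lin.colored_shape())  # "PASS", "COMPARE_ERROR", "COMPILE_ERROR", ...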
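For reference, each LOGKERN line is exactly what Compiled.get_linearizer writes above: the repr of an (ast, applied_opts) tuple, which kern_str_to_lin eval's and replays via apply_opt; with LOGKERN_LEVEL > 1 the beaten BEAM candidates are logged alongside the winning kernel. A hedged round-trip sketch, continuing from the lin above (the assert is illustrative and assumes apply_opt records each Opt in applied_opts):

# sketch: round-trip a kernel through the logged string format
from extra.optimization.helpers import kern_str_to_lin

kern_str = f"{(lin.ast, lin.applied_opts)}\n"  # same format device.py appends to the LOGKERN file
lin2 = kern_str_to_lin(kern_str)               # eval(...) then re-apply each Opt
assert lin2.applied_opts == lin.applied_opts   # illustrative check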