mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
* search: add BEAM_VERIFY option to validate search results refactor fuzz_linearizer comparison to allow it to be used for BEAM_VERIFY in device.py * search: fix to verify the beam_search result and not the fastest * search: fix typing and clean up * device: remove imports from test and add LOGKERN options LOGKERN output can be used with test/external/verify_kernel.py to validate correctness * fix example in verify_kernel.py * cleanup fixes * fix to use f-strings
57 lines
2.6 KiB
Python
57 lines
2.6 KiB
Python
import argparse
|
|
from collections import defaultdict
|
|
from extra.optimization.helpers import kern_str_to_lin
|
|
from test.external.fuzz_linearizer import compare_linearizer
|
|
from tinygrad.helpers import colored
|
|
from tinygrad.features.graph import print_tree
|
|
|
|
# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
|
|
|
|
# Example for GPT2:
|
|
# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERN=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
|
|
# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
|
|
|
|
def parse_kernel_lines(lines):
  """Return the kernel strings found in *lines*, whitespace-stripped, with blank lines dropped.

  A LOGKERN file holds one `(ast, applied_opts,)` tuple per line. Blank lines (e.g. stray empty
  lines at the end of the file) are skipped so they are never handed to kern_str_to_lin, which
  presumably cannot parse an empty string, and so the reported kernel count is accurate.

  Args:
    lines: any iterable of strings, such as an open text file.
  Returns:
    list of non-empty, stripped kernel strings in their original order.
  """
  return [s for line in lines if (s := line.strip())]

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
  parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
  parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
  parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
  parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
  args = parser.parse_args()

  # --kernel takes precedence over --file when both are given
  if args.kernel is not None:
    print("loading kernel from args")
    kern_strs = [args.kernel]
  elif args.file is not None:
    print(f"loading kernel from file '{args.file}'")
    with open(args.file, 'r') as f:
      kern_strs = parse_kernel_lines(f)
  else:
    raise RuntimeError("no kernel specified; use --kernel or --file options")

  print(f"verifying {len(kern_strs)} kernels")

  failed_ids = []
  # failure message -> [(ast, applied_opts), ...] for every kernel that failed with that message
  failures = defaultdict(list)
  for i, kern_str in enumerate(kern_strs):
    print(f"testing kernel {i}")
    test_lin = kern_str_to_lin(kern_str)
    for op in test_lin.ast: print_tree(op)
    print(test_lin.colored_shape())
    # only the first element of compare_linearizer's result (its status string) is inspected here
    if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
      failed_ids.append(i)
      failures[msg].append((test_lin.ast, test_lin.applied_opts))

  # print the AST and applied opts of every failing kernel, grouped by failure message
  for msg, errors in failures.items():
    for j, (ast, opts) in enumerate(errors):
      print(f"{msg} {j} AST: {ast}")
      print(f"{msg} {j} OPTS: {opts}\n")

  print(f"tested {len(kern_strs)} kernels")
  if failures:
    print(f"{failed_ids=}")
    for msg, errors in failures.items():
      print(f"{msg}: {len(errors)}")
    raise RuntimeError(f"failed on {len(failed_ids)} kernels")
  else:
    print(colored("all passed", "green"))