mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
test/external/verify_kernel: add support for single pickled kernel (#4836)
This commit is contained in:
27
test/external/verify_kernel.py
vendored
27
test/external/verify_kernel.py
vendored
@@ -17,6 +17,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
|
||||
parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
|
||||
parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
|
||||
parser.add_argument("--pkl", type=str, default=None, help="a pickle file containing a single tuple of ast and applied_opts")
|
||||
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
|
||||
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
|
||||
parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
|
||||
@@ -25,21 +26,31 @@ if __name__ == "__main__":
|
||||
|
||||
if args.kernel is not None:
|
||||
print("loading kernel from args")
|
||||
kern_strs = [args.kernel]
|
||||
test_lins = [kern_str_to_lin(args.kernel)]
|
||||
elif args.file is not None:
|
||||
print(f"loading kernel from file '{args.file}'")
|
||||
with open(args.file, 'r') as file:
|
||||
kern_strs = file.readlines()
|
||||
else:
|
||||
raise RuntimeError("no kernel specified; use --kernel or --file options")
|
||||
test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs]
|
||||
elif args.pkl is not None:
|
||||
print(f"loading kernel from pickle file '{args.file}'")
|
||||
import pickle
|
||||
with open(args.pkl, 'rb') as file:
|
||||
(ast, applied_opts,) = pickle.load(file)
|
||||
lin = Linearizer(*ast)
|
||||
for opt in applied_opts:
|
||||
lin.apply_opt(opt)
|
||||
test_lins = [lin]
|
||||
|
||||
print(f"verifying {len(kern_strs)} kernels")
|
||||
else:
|
||||
raise RuntimeError("no kernel specified; use --kernel, --file, or --pkl options")
|
||||
|
||||
print(f"verifying {len(test_lins)} kernels")
|
||||
|
||||
failed_ids = []
|
||||
failures = defaultdict(list)
|
||||
for i, kern_str in enumerate(kern_strs):
|
||||
for i, test_lin in enumerate(test_lins):
|
||||
print(f"testing kernel {i}")
|
||||
test_lin = kern_str_to_lin(kern_str)
|
||||
for op in test_lin.ast:
|
||||
print_tree(op)
|
||||
print(op)
|
||||
@@ -60,7 +71,7 @@ if __name__ == "__main__":
|
||||
print(f"{msg} {i} AST: {ast}")
|
||||
print(f"{msg} {i} OPTS: {opts}\n")
|
||||
|
||||
print(f"tested {len(kern_strs)} kernels")
|
||||
print(f"tested {len(test_lins)} kernels")
|
||||
if failures:
|
||||
print(f"{failed_ids=}")
|
||||
for msg, errors in failures.items():
|
||||
@@ -70,4 +81,4 @@ if __name__ == "__main__":
|
||||
if len(failed_ids) != args.expected_failures:
|
||||
raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
|
||||
else:
|
||||
print(colored("all passed", "green"))
|
||||
print(colored("all passed", "green"))
|
||||
|
||||
Reference in New Issue
Block a user