test/external/verify_kernel: add support for single pickled kernel (#4836)

This commit is contained in:
Francis Lam
2024-06-04 15:59:21 -07:00
committed by GitHub
parent e576aca044
commit 890e7c12bb

View File

@@ -17,6 +17,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
parser.add_argument("--pkl", type=str, default=None, help="a pickle file containing a single tuple of ast and applied_opts")
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
@@ -25,21 +26,31 @@ if __name__ == "__main__":
if args.kernel is not None:
print("loading kernel from args")
kern_strs = [args.kernel]
test_lins = [kern_str_to_lin(args.kernel)]
elif args.file is not None:
print(f"loading kernel from file '{args.file}'")
with open(args.file, 'r') as file:
kern_strs = file.readlines()
else:
raise RuntimeError("no kernel specified; use --kernel or --file options")
test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs]
elif args.pkl is not None:
print(f"loading kernel from pickle file '{args.file}'")
import pickle
with open(args.pkl, 'rb') as file:
(ast, applied_opts,) = pickle.load(file)
lin = Linearizer(*ast)
for opt in applied_opts:
lin.apply_opt(opt)
test_lins = [lin]
print(f"verifying {len(kern_strs)} kernels")
else:
raise RuntimeError("no kernel specified; use --kernel, --file, or --pkl options")
print(f"verifying {len(test_lins)} kernels")
failed_ids = []
failures = defaultdict(list)
for i, kern_str in enumerate(kern_strs):
for i, test_lin in enumerate(test_lins):
print(f"testing kernel {i}")
test_lin = kern_str_to_lin(kern_str)
for op in test_lin.ast:
print_tree(op)
print(op)
@@ -60,7 +71,7 @@ if __name__ == "__main__":
print(f"{msg} {i} AST: {ast}")
print(f"{msg} {i} OPTS: {opts}\n")
print(f"tested {len(kern_strs)} kernels")
print(f"tested {len(test_lins)} kernels")
if failures:
print(f"{failed_ids=}")
for msg, errors in failures.items():
@@ -70,4 +81,4 @@ if __name__ == "__main__":
if len(failed_ids) != args.expected_failures:
raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
else:
print(colored("all passed", "green"))
print(colored("all passed", "green"))