From 890e7c12bb5b736c8b0885601361d7a4347c4a50 Mon Sep 17 00:00:00 2001 From: Francis Lam Date: Tue, 4 Jun 2024 15:59:21 -0700 Subject: [PATCH] test/external/verify_kernel: add support for single pickled kernel (#4836) --- test/external/verify_kernel.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/test/external/verify_kernel.py b/test/external/verify_kernel.py index a8cbb3fd4d..832b7c87d0 100644 --- a/test/external/verify_kernel.py +++ b/test/external/verify_kernel.py @@ -17,6 +17,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501 parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)") parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line") + parser.add_argument("--pkl", type=str, default=None, help="a pickle file containing a single tuple of ast and applied_opts") parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison") parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison") parser.add_argument("--timing", action='store_true', help="show final timing for the kernel") @@ -25,21 +26,31 @@ if __name__ == "__main__": if args.kernel is not None: print("loading kernel from args") - kern_strs = [args.kernel] + test_lins = [kern_str_to_lin(args.kernel)] elif args.file is not None: print(f"loading kernel from file '{args.file}'") with open(args.file, 'r') as file: kern_strs = file.readlines() - else: - raise RuntimeError("no kernel specified; use --kernel or --file options") + test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs] + elif args.pkl is not None: + print(f"loading kernel from pickle file '{args.file}'") + import pickle + with open(args.pkl, 'rb') as file: + (ast, applied_opts,) = pickle.load(file) + lin = Linearizer(*ast) + for opt in applied_opts: + lin.apply_opt(opt) + test_lins = [lin] - print(f"verifying {len(kern_strs)} kernels") + else: + raise RuntimeError("no kernel specified; use --kernel, --file, or --pkl options") + + print(f"verifying {len(test_lins)} kernels") failed_ids = [] failures = defaultdict(list) - for i, kern_str in enumerate(kern_strs): + for i, test_lin in enumerate(test_lins): print(f"testing kernel {i}") - test_lin = kern_str_to_lin(kern_str) for op in test_lin.ast: print_tree(op) print(op) @@ -60,7 +71,7 @@ if __name__ == "__main__": print(f"{msg} {i} AST: {ast}") print(f"{msg} {i} OPTS: {opts}\n") - print(f"tested {len(kern_strs)} kernels") + print(f"tested {len(test_lins)} kernels") if failures: print(f"{failed_ids=}") for msg, errors in failures.items(): @@ -70,4 +81,4 @@ if __name__ == "__main__": if len(failed_ids) != args.expected_failures: raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}") else: - print(colored("all passed", "green")) \ No newline at end of file + print(colored("all passed", "green"))