test/external/fuzz_linearizer: fix for new AST changes (#5519)

* test/external/fuzz_linearizer: fix for new AST changes

also add beautiful_mnist failures

* add CLANG and LLVM to test_failure_35 failed_platforms

* fix test_linearizer_failure names
This commit is contained in:
Francis Lam
2024-07-16 21:08:07 -07:00
committed by GitHub
parent 85d4ca7caa
commit 2d53abb04a
4 changed files with 51 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ import random, traceback, ctypes, argparse
from typing import List, Tuple, DefaultDict
import numpy as np
from collections import defaultdict
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin, kern_str_to_lin
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
@@ -157,7 +157,14 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")
# stop if kernel uops repeat
tuops = tuplize_uops(test_lin.linearize().uops.uops)
try: tuops = tuplize_uops(test_lin.linearize().uops.uops)
except BaseException as e:
print(test_lin.ast)
print(test_lin.applied_opts)
print(e)
failures["LINEARIZE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
continue
if tuops in seen_uops: continue
seen_uops[tuops] = tuple(test_lin.applied_opts)
@@ -187,6 +194,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
parser.add_argument("--logfile", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
@@ -199,6 +207,12 @@ if __name__ == "__main__":
print(f"loading ASTs from file '{args.file}'")
with open(args.file, 'r') as file:
ast_strs = file.readlines()
elif args.logfile is not None:
print(f"loading ASTs from LOGKERNS file '{args.file}'")
with open(args.logfile, 'r') as file:
kern_strs = file.readlines()
test_lins = [kern_str_to_lin(kern_str) for kern_str in kern_strs]
ast_strs = [f"{lin.ast}" for lin in test_lins]
else:
print("loading ASTs from world")
ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)

View File

@@ -51,9 +51,8 @@ if __name__ == "__main__":
failures = defaultdict(list)
for i, test_lin in enumerate(test_lins):
print(f"testing kernel {i}")
for op in test_lin.ast:
print_tree(op)
print(op)
print_tree(test_lin.ast)
print(test_lin.ast)
print(test_lin.applied_opts)
unoptimized_lin = Kernel(test_lin.ast)
unoptimized_lin.required_optimizations()