fuzz_linearizer: reduce debug verbosity and make easier for CI usage (#3942)

* fuzz_linearizer: reduce debug verbosity and make easier for CI usage

* rename FUZZ_BEAM to FUZZ_ALL_ACTIONS (not choosing a subset)
* skip simple ASTs (easier to use with LOGOPS output)
* don't fuzz a previously seen AST
* add options to allow non-zero --expected-failures

* clean up naming and use set
This commit is contained in:
Francis Lam
2024-03-26 13:25:24 -07:00
committed by GitHub
parent 8df6587c41
commit 5530b0cbed
3 changed files with 42 additions and 19 deletions

View File

@@ -315,7 +315,7 @@ jobs:
- name: Test Beam Search
run: PYTHONPATH="." METAL=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Fuzz Test linearizer
run: PYTHONPATH="." METAL=1 CACHELEVEL=0 FUZZ_BEAM=1 DEPTH=2 FUZZ_N=48 FUZZ_MAX_SIZE=10000000 python test/external/fuzz_linearizer.py
run: PYTHONPATH="." METAL=1 CACHELEVEL=0 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=48 FUZZ_MAX_SIZE=10000000 python test/external/fuzz_linearizer.py
# testwebgl:

View File

@@ -10,7 +10,7 @@ from tinygrad.codegen.kernel import Opt
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.features.graph import print_tree
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
from tinygrad.ops import LazyOp
from tinygrad.ops import LazyOp, UnaryOps, BufferOps
def tuplize_uops(uops:List[UOp]) -> Tuple:
return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
@@ -101,11 +101,12 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
except AssertionError as e:
if DEBUG >= 2:
print(f"COMPARE_ERROR details: {e}")
mismatch_indices = np.where(~np.isclose(result, ground_truth, rtol=rtol, atol=atol))
mismatched_result = result[mismatch_indices]
mismatched_ground_truth = ground_truth[mismatch_indices]
for i, idx in enumerate(mismatch_indices[0]):
print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
if getenv("DEBUG_VALUES") > 0:
mismatch_indices = np.where(~np.isclose(result, ground_truth, rtol=rtol, atol=atol))
mismatched_result = result[mismatch_indices]
mismatched_ground_truth = ground_truth[mismatch_indices]
for i, idx in enumerate(mismatch_indices[0]):
print(f"mismatch at {idx=}: result={mismatched_result[i]} <> ground_truth={mismatched_ground_truth[i]}")
return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
return ("PASS", rawbufs, var_vals, ground_truth,)
@@ -121,32 +122,36 @@ def fuzz_linearizer(lin: Linearizer):
failures:DefaultDict[str, List[Tuple[Tuple[LazyOp,...],List[Opt]]]] = defaultdict(list)
rawbufs, var_vals, ground_truth = None, None, None
FUZZ_BEAM = getenv("FUZZ_BEAM", 0)
FUZZ_ALL_ACTIONS = getenv("FUZZ_ALL_ACTIONS", 0)
FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
FUZZ_IGNORE_SIMPLE_OPS = getenv("FUZZ_IGNORE_SIMPLE_OPS", 1)
if FUZZ_MAX_SIZE > 0 and prod(lin.full_shape) > FUZZ_MAX_SIZE:
print("skipping large kernel")
return failures
if FUZZ_IGNORE_SIMPLE_OPS and _is_simple(lin):
print("skipping simple kernel")
return failures
for depth in range(getenv("DEPTH", 1 if FUZZ_BEAM else 10)):
for depth in range(getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)):
next_lins = []
for lin in last_lins:
actions = get_linearizer_actions(lin, include_0=False)
if FUZZ_BEAM: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
if FUZZ_ALL_ACTIONS: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
if not actions: continue
test_lins = list(actions.values())
if not FUZZ_BEAM: test_lins = [random.choice(test_lins)]
if not FUZZ_ALL_ACTIONS: test_lins = [random.choice(test_lins)]
for test_lin in test_lins:
if not FUZZ_BEAM and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")
if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")
# stop if kernel uops repeat
tuops = tuplize_uops(test_lin.linearize().uops.uops)
if tuops in seen_uops:
continue
if tuops in seen_uops: continue
seen_uops[tuops] = tuple(test_lin.applied_opts)
if not FUZZ_BEAM: print(test_lin.colored_shape())
if not FUZZ_ALL_ACTIONS: print(test_lin.colored_shape())
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth)
if msg != "PASS":
@@ -159,13 +164,20 @@ def fuzz_linearizer(lin: Linearizer):
next_lins.append(test_lin)
last_lins = next_lins
if FUZZ_BEAM: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
if FUZZ_ALL_ACTIONS: print(f"depth={depth} total_lins={len(last_lins)} {failures=}")
return failures
def _is_simple(lin: Linearizer) -> bool:
if len(lin.ast) > 1: return False
ast:LazyOp = lin.ast[0]
if ast.src[0] and ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op == BufferOps.LOAD: return True
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
args = parser.parse_args()
if args.ast is not None:
@@ -183,9 +195,13 @@ if __name__ == "__main__":
tested = 0
failed_ids = []
failures = defaultdict(list)
seen_ast_strs = set()
for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue
if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU
if ast in seen_ast_strs: continue
seen_ast_strs.add(ast)
print(f"testing ast {i}")
tested += 1
lin = ast_str_to_lin(ast)
@@ -198,13 +214,16 @@ if __name__ == "__main__":
for msg, errors in failures.items():
for i, (ast, opts) in enumerate(errors):
print(f"{msg} {i} AST: {ast}")
print(f"{msg} {i} OPTS: {opts}\n")
print(f"{msg} {i} kernel: {(ast,opts)}") # easier to use with output with verify_kernel.py
print(f"{tested=}")
if failures:
print(f"{failed_ids=}")
for msg, errors in failures.items():
print(f"{msg}: {len(errors)}")
if len(failed_ids) == args.expected_failures:
print(colored(f"{len(failed_ids)} failed as expected", "yellow"))
if len(failed_ids) != args.expected_failures:
raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
else:
print(colored("all passed", "green"))

View File

@@ -19,6 +19,7 @@ if __name__ == "__main__":
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
args = parser.parse_args()
if args.kernel is not None:
@@ -58,6 +59,9 @@ if __name__ == "__main__":
print(f"{failed_ids=}")
for msg, errors in failures.items():
print(f"{msg}: {len(errors)}")
raise RuntimeError(f"failed on {len(failed_ids)} kernels")
if len(failed_ids) == args.expected_failures:
print(colored(f"{len(failed_ids)} failed as expected", "yellow"))
if len(failed_ids) != args.expected_failures:
raise RuntimeError(f"failed on {len(failed_ids)} kernels, expected {args.expected_failures}")
else:
print(colored("all passed", "green"))