diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index d237af50df..9e186310c2 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -12,6 +12,7 @@ from tinygrad.engine.graph import print_tree from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG from tinygrad.ops import LazyOp, UnaryOps, BufferOps +from test.helpers import is_dtype_supported def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops]) @@ -211,9 +212,13 @@ if __name__ == "__main__": if ast in seen_ast_strs: continue seen_ast_strs.add(ast) + lin = ast_str_to_lin(ast) + if not all(is_dtype_supported(buf.dtype) for buf in lin.bufs): + print("skipping kernel due to not supported dtype") + continue + print(f"testing ast {i}") tested += 1 - lin = ast_str_to_lin(ast) fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol) if fuzz_failures: failed_ids.append(i)