From 30fa03243eeb26cb354c759d20203e623417df95 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 21 Mar 2024 14:12:27 -0400
Subject: [PATCH] reuse fuzz_linearizer.compare_linearizer in test_linearizer_failures (#3861)

---
 test/external/fuzz_linearizer.py | 16 ++++++++++++++--
 test/test_linearizer_failures.py | 19 ++++++-------------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 1aef1a750f..a44ea2ae18 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -25,6 +25,7 @@ def get_fuzz_rawbufs(lin):
   rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True, size=rawbufs[0].size+RED_AREA_SIZE)
   with Context(DEBUG=0):
     for rawbuf in rawbufs[1:]:
+      # TODO: better rand values based on dtype
       t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
       if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
   return rawbufs
@@ -58,6 +59,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
   return "PASS"
 
 def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
+  # TODO: raise specific fuzzing errors instead of str, and propagate the error message
   try:
     if rawbufs is None:
       rawbufs = get_fuzz_rawbufs(lin)
@@ -65,7 +67,11 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
       rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
   except BaseException:
     return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
-  if var_vals is None: var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
+
+  if var_vals is None:
+    # TODO: handle symbolic max case
+    var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
+
   if ground_truth is None:
     unoptimized = Linearizer(*lin.ast)
     unoptimized.required_optimizations()
@@ -77,7 +83,13 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
   if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
     return (run_msg, rawbufs, var_vals, ground_truth,)
   result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
-  return ("PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
+
+  try:
+    np.testing.assert_allclose(result, ground_truth, rtol=rtol, atol=atol)
+  except AssertionError:
+    return ("COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
+
+  return ("PASS", rawbufs, var_vals, ground_truth,)
 
 def fuzz_linearizer(lin: Linearizer):
   SEED = getenv("SEED", 42)
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index 65e3f6301f..0f0079f570 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -6,7 +6,7 @@ from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI
-from test.external.fuzz_linearizer import run_linearizer, get_fuzz_rawbufs, get_fuzz_rawbuf_like
+from test.external.fuzz_linearizer import compare_linearizer
 from test.helpers import is_dtype_supported
 from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
 
@@ -17,12 +17,6 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms):
   if any(b.dtype == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
 
-  rawbufs = get_fuzz_rawbufs(lin)
-  var_vals = {v: random.randint(v.min, v.max) for v in lin.ast[0].vars()}
-
-  assert run_linearizer(lin, rawbufs, var_vals) == "PASS" or Device.DEFAULT in failed_platforms, "Failed running non-optimized ast"
-  ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
-
   for opt in opts:
     try:
       lin.apply_opt(opt)
@@ -31,13 +25,12 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms):
       assert Device.DEFAULT not in failed_platforms
       return
 
-  rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
-  linearizer_passed = (run_linearizer(lin, rawbufs, var_vals) == "PASS")
-  output_allclose = np.allclose(ground_truth, np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np), rtol=1e-2, atol=1e-2)
-  if Device.DEFAULT not in failed_platforms:
-    assert linearizer_passed and output_allclose, f"{linearizer_passed=}, {output_allclose=}"
+  compare_result = compare_linearizer(lin)
+  if compare_result[0] in ["PASS", "KernelOptError"]:
+    # it's considered fixed if we invalidated the opts
+    assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
   else:
-    assert not linearizer_passed or not output_allclose, f"{linearizer_passed=}, {output_allclose=}"
+    assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {compare_result[0]}"
 
 @unittest.skipIf(CI and Device.DEFAULT=="CUDA", "failed on CUDA CI")
 class TestLinearizerFailures(unittest.TestCase):