From 9851e2c3b9d62511692feb163b271447b35f76c0 Mon Sep 17 00:00:00 2001
From: Francis Lam
Date: Tue, 19 Mar 2024 10:19:54 -0700
Subject: [PATCH] test_linearizer_failure: add failure 26 from a gpt2 kernel (#3821)

found during a full fuzz test of all applied_opts combos to a depth of 3 on the gpt2 kernels
---
 test/test_linearizer_failures.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index d77567974f..f77dec588f 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -192,5 +192,15 @@ class TestLinearizerFailures(unittest.TestCase):
     opts = [Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=4)]
     helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
 
+  # COMPARE_ERROR from GPT2 kernel - stems from uops.py self.simplify_phi_loops
+  def test_failure_26(self):
+    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(128, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
+    all_failing_opts = [
+      [Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0)],
+      [Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4)],
+    ]
+    for opts in all_failing_opts:
+      helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "HSA", "CUDA"])
+
 if __name__ == '__main__':
   unittest.main()