From 0914c2fec9a468bb95d339f07175dde3c5b080ce Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 14 Nov 2024 06:00:33 +0200 Subject: [PATCH] add TestLinearizerFailures test_failure_56 and test_failure_57 (#7682) * add test_failure_56 and test_failure_57 * so it's only METAL=1 --- test/test_linearizer_failures.py | 92 ++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 05dbe3c826..26dd7b23af 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1339,5 +1339,97 @@ class TestLinearizerFailures(unittest.TestCase): opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) + def test_failure_56(self): + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( + x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()), + x10,)), + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(1936, 121, 11, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x8, + UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), + x10,)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + x22,)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + x22,)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), + x22,)),)), + x7,)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=2, amt=32)] + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) + + def test_failure_57(self): + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( + x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()), + x10,)), + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(1936, 121, 11, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x8, + UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), + x10,)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + x22,)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + x22,)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), + x22,)),)), + x7,)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)] + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) + if __name__ == '__main__': unittest.main()