add TestLinearizerFailures test_failure_56 and test_failure_57 (#7682)

* add test_failure_56 and test_failure_57

* so it's only METAL=1
This commit is contained in:
qazal
2024-11-14 06:00:33 +02:00
committed by GitHub
parent a87813f063
commit 0914c2fec9

View File

@@ -1339,5 +1339,97 @@ class TestLinearizerFailures(unittest.TestCase):
opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_56(self):
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
x10,)),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(1936, 121, 11, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x8,
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
x10,)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
x22,)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
x22,)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
x22,)),)),
x7,)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=2, amt=32)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
def test_failure_57(self):
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
x10,)),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(1936, 121, 11, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x8,
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
x10,)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
x22,)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
x22,)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
x22,)),)),
x7,)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
if __name__ == '__main__':
unittest.main()