COMMUTATIVE flipping is only for ints (#8996)

* COMMUTATIVE flipping is only for ints [pr]

* no pr

* comm fixes this
This commit is contained in:
George Hotz
2025-02-10 12:01:28 +08:00
committed by GitHub
parent 2983285315
commit e618efce22
2 changed files with 4 additions and 4 deletions

View File

@@ -1368,7 +1368,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=2, arg=32)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_57(self):
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
@@ -1409,7 +1409,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
if __name__ == '__main__':
unittest.main()