diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py index e9e5ed066c..24f94c640a 100644 --- a/test/external/external_test_valid_remove.py +++ b/test/external/external_test_valid_remove.py @@ -2,7 +2,7 @@ import unittest from tinygrad import Device -from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps +from tinygrad.ops import UOp, Ops, BinaryOps from tinygrad.engine.search import Opt, OptOps from tinygrad.dtype import dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -17,12 +17,12 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)), @@ -35,14 +35,14 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x19:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=( x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.CONST, dtypes.float, arg=1.0, src=( x20,)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.EXP2, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( x5, UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=( x20,)),)),)), @@ -67,19 +67,19 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( x18:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)), diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7e8f518cbd..ca2aaa49af 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -600,16 +600,16 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501 - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.int, arg=None, src=( ast_const(dtypes.int, -1, (1, 20, 1)), UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=( - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501 @@ -617,7 +617,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501 - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501 ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)), @@ -632,16 +632,16 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( ast_const(dtypes.int, 200, (1, 1)), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.int, arg=None, src=( ast_const(dtypes.int, -1, (1, 1)), UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=( - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501 @@ -649,7 +649,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501 ast_const(dtypes.bool, True, (200, 1)),)),)), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501 ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)), @@ -732,16 +732,16 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (N, 1, 1)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), ld1.to_uop(),)), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, src=( @@ -765,16 +765,16 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (1, 1, N)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), ld1.to_uop(),)), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, 1, N)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, src=( @@ -801,16 +801,16 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501 - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( UOp(Ops.LOAD, dtypes.float, src=( @@ -1604,7 +1604,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(Ops.CAST, dtypes.float, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 @@ -1631,9 +1631,9 @@ class TestFloat4(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501 @@ -1950,9 +1950,9 @@ class TestKernelOpts(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.int, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index eaf5df9818..ce45ee53f7 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -5,7 +5,7 @@ import unittest from test.helpers import ast_const from tinygrad import Device, dtypes -from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps +from tinygrad.ops import UOp, Ops, BinaryOps from tinygrad.helpers import getenv from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.engine.search import Opt, OptOps @@ -18,12 +18,12 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MAX, dtypes.half, arg=None, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)), @@ -51,10 +51,10 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -63,7 +63,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.bool, True, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -107,11 +107,11 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -137,10 +137,10 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.long, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -169,9 +169,9 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.long, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -201,7 +201,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)), diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index b3754aa7ad..081994a904 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -2,7 +2,7 @@ import unittest, random import numpy as np from tinygrad.codegen.kernel import Kernel, KernelOptError -from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps +from tinygrad.ops import UOp, Ops, BinaryOps from tinygrad.engine.search import Opt, OptOps from tinygrad import Device, dtypes, Tensor from tinygrad.helpers import CI @@ -44,8 +44,8 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -89,9 +89,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( ast_const(dtypes.float, 0.1464405059814453, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( @@ -109,7 +109,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -137,13 +137,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.SQRT, dtypes.float, arg=None, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - x9:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -181,9 +181,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -200,13 +200,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -220,19 +220,19 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), @@ -246,11 +246,11 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( + UOp(Ops.SQRT, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(1,), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), @@ -262,7 +262,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)), @@ -277,9 +277,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -299,10 +299,10 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x6:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=( + x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -323,9 +323,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(51864, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -344,9 +344,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -365,12 +365,12 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 0, 14, 1, 196, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -383,9 +383,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.SQRT, dtypes.float, arg=None, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -403,7 +403,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(0, 0, 1, 40, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -436,13 +436,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(1536, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(0, 0, 36, 9, 0, 0, -3, -1), offset=8, mask=None, contiguous=False),)), src=()),)), @@ -478,7 +478,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(4, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -504,20 +504,20 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( x4:=ast_const(dtypes.float, 0.000244140625, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -539,51 +539,51 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=9, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=13, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=15, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 17280, 180, 18, 1), offset=19, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)),)),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=17, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), ast_const(dtypes.float, 2.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), - x80:=UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + x80:=UOp(Ops.RECIP, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -622,7 +622,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -637,7 +637,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -692,17 +692,17 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.bfloat16.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.bfloat16, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.WHERE, dtypes.bfloat16, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( x5:=UOp(Ops.CAST, dtypes.bfloat16, arg=None, src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), x9:=ast_const(dtypes.bfloat16, 230.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( + UOp(Ops.ADD, dtypes.bfloat16, arg=None, src=( + UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=( + UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=( x5, ast_const(dtypes.bfloat16, 0.004347826086956522, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), @@ -710,10 +710,10 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.bfloat16, 1.99375e-07, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.bfloat16, arg=None, src=( + UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=( + UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=( + UOp(Ops.ADD, dtypes.bfloat16, arg=None, src=( x5, x9,)), ast_const(dtypes.bfloat16, 0.0012987012987012987, st_src=( @@ -732,7 +732,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 128, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 128), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), @@ -750,7 +750,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(3072, 0, 0, 32, 1, 1024, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -767,9 +767,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.EXP2, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)), @@ -792,7 +792,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 256, 4, 16, 4, 16), strides=(0, 50176, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 256), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(1048576, 0, 0, 64, 1, 4096, 1088, 17), offset=0, mask=None, contiguous=False))), src=()),)), @@ -809,26 +809,26 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( x5, x10:=ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( ast_const(dtypes.float, 0.06788442333021306, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), x5,)), ast_const(dtypes.float, -0.03394221166510653, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=-26040, mask=((26040, 32640),), contiguous=False),)), src=()),)), @@ -849,9 +849,9 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(77, 0, 0, 7, 1, 0, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -873,7 +873,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.uchar, arg=None, src=( - UOp(Ops.ALU, dtypes.uint, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.uint, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=( UOp(Ops.CAST, dtypes.uint, arg=None, src=( ast_const(dtypes.uchar, 1, st_src=( @@ -892,10 +892,10 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), @@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), @@ -940,10 +940,10 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), @@ -967,7 +967,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -988,7 +988,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 128, 4, 58, 4, 58), strides=(0, 401408, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 128), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(6889472, 0, 0, 464, 2, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), @@ -1052,18 +1052,18 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 3, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.MUL, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)), @@ -1071,12 +1071,12 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), x21:=ast_const(dtypes.bool, True, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -1091,13 +1091,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.MUL, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -1112,7 +1112,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) @@ -1125,7 +1125,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), @@ -1143,7 +1143,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 0, 56, 1, 3136, 0, 0, 802816), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(10, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -1176,14 +1176,14 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.bool, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -1203,25 +1203,25 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.half, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( + UOp(Ops.RECIP, dtypes.half, arg=None, src=( + UOp(Ops.ADD, dtypes.half, arg=None, src=( + UOp(Ops.WHERE, dtypes.half, arg=None, src=( x6:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.CONST, dtypes.half, arg=1.0, src=()), x9:=UOp(Ops.CONST, dtypes.half, arg=0.0, src=()),)), - UOp(Ops.ALU, dtypes.half, arg=UnaryOps.EXP2, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( + UOp(Ops.EXP2, dtypes.half, arg=None, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( + UOp(Ops.WHERE, dtypes.half, arg=None, src=( x6, UOp(Ops.CONST, dtypes.half, arg=2.0, src=()), x9,)), - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), @@ -1231,7 +1231,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( + UOp(Ops.WHERE, dtypes.half, arg=None, src=( x6, UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()), x9,)),)),)),)),)),)),)) @@ -1250,7 +1250,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 256), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), @@ -1267,29 +1267,29 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=( - UOp(Ops.ALU, dtypes.uchar, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.uchar, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.uchar, arg=None, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( + UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( - UOp(Ops.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( + UOp(Ops.WHERE, dtypes.int, arg=None, src=( UOp(Ops.VALID, dtypes.bool, arg=None, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)), UOp(Ops.CONST, dtypes.int, arg=1, src=()), x20:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),)), - UOp(Ops.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( + UOp(Ops.WHERE, dtypes.int, arg=None, src=( x22:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.CONST, dtypes.int, arg=-1, src=()), x20,)),)),)), - UOp(Ops.ALU, dtypes.bool, arg=TernaryOps.WHERE, src=( + UOp(Ops.WHERE, dtypes.bool, arg=None, src=( x22, UOp(Ops.CONST, dtypes.bool, arg=True, src=()), UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),)) @@ -1307,7 +1307,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), @@ -1327,7 +1327,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, W, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, W), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index 5c152fec11..5ff6e0ca74 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps from tinygrad.engine.search import time_linearizer, bufs_from_lin # stuff needed to unpack a kernel -from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps +from tinygrad.ops import UOp, Ops, BinaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -28,13 +28,13 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MAX, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), @@ -46,12 +46,12 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.SQRT, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( x23:=ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.RECIP, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( x23, ast_const(dtypes.float, 1e-05, st_src=( UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), @@ -69,7 +69,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)), @@ -86,7 +86,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), @@ -103,7 +103,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), @@ -120,7 +120,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), @@ -137,7 +137,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), @@ -154,7 +154,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), diff --git a/test/test_schedule.py b/test/test_schedule.py index f4652c198f..9cfc06a00d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1689,10 +1689,10 @@ class TestIndexing(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( - UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=( + UOp(Ops.ADD, dtypes.int, arg=None, src=( UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501 - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=( UOp(Ops.LOAD, dtypes.int, arg=None, src=( x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501 @@ -1776,11 +1776,11 @@ class TestIndexing(unittest.TestCase): UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=( x1, UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8)), src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()), diff --git a/test/test_search.py b/test/test_search.py index 0e0d1ce3fa..b91d003e54 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -108,12 +108,12 @@ class TestBEAM(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501 diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 1025424dda..502c0df635 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,7 +2,7 @@ from typing import List import unittest, time from tinygrad import dtypes, Device from tinygrad.helpers import DEBUG -from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo +from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo from tinygrad.ops import UPat, PatternMatcher from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index @@ -24,7 +24,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase): c1 = UOp.const(dtypes.int, 1) c2 = UOp.const(dtypes.int, 2) st = time.perf_counter() - uops = [UOp(Ops.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)] + uops = [UOp(Ops.ADD, dtypes.int, (c1, c2)) for _ in range(10000)] et = time.perf_counter() - st print(f"created {len(uops)} uops in {et*1000:.2f} ms") @@ -35,9 +35,9 @@ class TestGraphRewriteEfficiency(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1), strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0), offset=0, mask=None, contiguous=False),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 10)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=( @@ -194,7 +194,7 @@ class TestUOpGraph(unittest.TestCase): def test_add_constant_fold(self): c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) - out = UOp(Ops.ALU, dtypes.float, (c1, c2), BinaryOps.ADD) + out = UOp(Ops.ADD, dtypes.float, (c1, c2)) uops = to_uops_list([out]) self.assertEqual(len(uops), 1) out = uops[-1] @@ -204,9 +204,9 @@ class TestUOpGraph(unittest.TestCase): def test_where_same_fold(self): v = UOp.variable('tmp', 0, 1) c0 = UOp(Ops.CONST, dtypes.int, arg=0) - vc = UOp(Ops.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) + vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0)) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) - out = UOp(Ops.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE) + out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1)) uops = to_uops_list([out]) self.assertEqual(len(uops), 1) out = uops[-1] @@ -217,7 +217,7 @@ class TestUOpGraph(unittest.TestCase): bf = UOp(Ops.CONST, dtypes.bool, arg=False) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) - out = UOp(Ops.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE) + out = UOp(Ops.WHERE, dtypes.float, (bf, c1, c2)) uops = to_uops_list([out]) self.assertEqual(len(uops), 1) out = uops[-1] @@ -240,7 +240,7 @@ class TestUOpGraph(unittest.TestCase): ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx)) vec = UOp(Ops.VECTORIZE, dtypes.float.vec(2), (ld,)) x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0) - alu = UOp(Ops.ALU, dtypes.float, (x, ), UnaryOps.SQRT) + alu = UOp(Ops.SQRT, dtypes.float, (x, )) out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu)) uops = to_uops_list([out]) self.assertEqual(len([x for x in uops if x.op is Ops.VECTORIZE]), 0) @@ -375,12 +375,12 @@ class TestUOpGraph(unittest.TestCase): v = UOp.variable("tmp", 0, 1) c2 = UOp(Ops.CONST, dtypes.int, arg=2) c4 = UOp(Ops.CONST, dtypes.int, arg=4) - vc = UOp(Ops.ALU, dtypes.int, (v, c2), BinaryOps.ADD) - out = UOp(Ops.ALU, dtypes.int, (vc, c4), BinaryOps.ADD) + vc = UOp(Ops.ADD, dtypes.int, (v, c2)) + out = UOp(Ops.ADD, dtypes.int, (vc, c4)) uops = to_uops_list([out]) self.assertEqual(len(uops), 3) out = uops[-1] - self.assertEqual(out.op, BinaryOps.ADD) + self.assertEqual(out.op, Ops.ADD) self.assertEqual(out.src[1].op, Ops.CONST) self.assertEqual(out.src[1].arg, 6) @@ -436,7 +436,7 @@ class TestUOpGraph(unittest.TestCase): cf = UOp.const(dtypes.float, 0.0) r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False)) r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False)) - alu = UOp(Ops.ALU, dtypes.int, (r2, r1), BinaryOps.MUL) + alu = UOp(Ops.MUL, dtypes.int, (r2, r1)) store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) uops = to_uops_list([store]) ranges = [x for x in uops if x.op is Ops.RANGE] diff --git a/test/test_uops.py b/test/test_uops.py index a7676a61ce..60dca04273 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -33,7 +33,7 @@ def _test_single_value(vals, op, dts): buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts)) - alu = uop(uops, Ops.ALU, output_dtype, loads, op) + alu = uop(uops, op, output_dtype, loads) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)] @@ -48,7 +48,7 @@ def _test_single_value_const(vals, op, dts): output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) - alu = uop(uops, Ops.ALU, output_dtype, loads, op) + alu = uop(uops, op, output_dtype, loads) out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg([out]) @@ -332,8 +332,8 @@ class TestAssembly(unittest.TestCase): c1 = UOp(Ops.CONST, dtypes.int, (), 2) c2 = UOp(Ops.CONST, dtypes.int, (), 3) l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) - a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.MUL) - a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL) + a1 = UOp(Ops.MUL, dtypes.int, (l1, c1)) + a2 = UOp(Ops.MUL, dtypes.int, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) self.assertEqual(uops[-1].op, BinaryOps.SHL) @@ -344,8 +344,8 @@ class TestAssembly(unittest.TestCase): c1 = UOp(Ops.CONST, dtypes.int, (), 2) c2 = UOp(Ops.CONST, dtypes.int, (), 3) l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) - a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV) - a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV) + a1 = UOp(Ops.IDIV, dtypes.int, (l1, c1)) + a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) self.assertEqual(uops[-1].op, BinaryOps.SHR) @@ -357,8 +357,8 @@ class TestUOpMethod(unittest.TestCase): a = UOp(Ops.CONST, dtypes.float, (), 2.0) b = UOp(Ops.CONST, dtypes.float, (), 3.0) - add = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.ADD) - mul = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.MUL) + add = UOp(Ops.ADD, dtypes.float, (a, b)) + mul = UOp(Ops.MUL, dtypes.float, (a, b)) assert (add < mul) or (mul < add), "add and mul with same src should have an order" def test_uop_variables(self): diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 25865639ab..10b12dbc6c 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, GlobalCounters from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item from tinygrad.codegen.linearize import linearize_uop -from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, Ops, UOp +from tinygrad.ops import flops_mem, Ops, UOp from tinygrad.dtype import dtypes from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError @@ -125,8 +125,8 @@ class TestUOpsStats(unittest.TestCase): u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),)) u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),)) u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3) - u4 = UOp(Ops.ALU, dtypes.int, (u1,u2), BinaryOps.MUL) - u5 = UOp(Ops.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) + u4 = UOp(Ops.MUL, dtypes.int, (u1,u2)) + u5 = UOp(Ops.ADD, dtypes.int, (u4,u3)) uops = linearize_uop(u5.sink()) globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) @@ -135,7 +135,7 @@ class TestUOpsStats(unittest.TestCase): u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),)) u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),)) u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3) - u4 = UOp(Ops.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC) + u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3)) uops_fma = linearize_uop(u4.sink()) self.assertEqual(flops_mem(uops), flops_mem(uops_fma)) diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index 7925fb0dff..45b2243de3 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -54,7 +54,7 @@ class TestPatternMatcher(unittest.TestCase): def test_uop(self): matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) - c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD) + c2 = UOp(Ops.ADD, dtypes.float, (c1, c1)) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) @@ -63,7 +63,7 @@ class TestPatternMatcher(unittest.TestCase): c1 = UOp(Ops.CONST, dtypes.bool, arg=False) c2 = UOp(Ops.CAST, dtypes.int, (c1,)) c3 = UOp(Ops.CONST, dtypes.float, arg=1.0) - c4 = UOp(Ops.ALU, dtypes.float, (c3, c3), BinaryOps.ADD) + c4 = UOp(Ops.ADD, dtypes.float, (c3, c3)) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), c2) self.assertEqual(matcher.rewrite(c4), None) @@ -76,8 +76,8 @@ class TestPatternMatcher(unittest.TestCase): ]) c1 = UOp(Ops.CONST, dtypes.float, arg=0.0) c2 = UOp(Ops.CONST, dtypes.bool, arg=False) - c3 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX) - c4 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL) + c3 = UOp(Ops.MAX, dtypes.float, (c1, c1)) + c4 = UOp(Ops.MUL, dtypes.float, (c1, c1)) c5 = UOp(Ops.CONST, dtypes.int, arg=-1) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), c2) @@ -93,11 +93,11 @@ class TestPatternMatcher(unittest.TestCase): y1 = UOp(Ops.CONST, dtypes.int, arg=1) y2 = UOp(Ops.CONST, dtypes.int, arg=2) y3 = UOp(Ops.CONST, dtypes.int, arg=-1) - c1 = UOp(Ops.ALU, dtypes.int, (y1, y2), BinaryOps.MUL) - c2 = UOp(Ops.ALU, dtypes.int, (y2, y2), BinaryOps.MUL) - c3 = UOp(Ops.ALU, dtypes.int, (y3, y2), BinaryOps.MUL) - c4 = UOp(Ops.ALU, dtypes.int, (y2, y1), BinaryOps.MUL) - c5 = UOp(Ops.ALU, dtypes.int, (y2, y3), BinaryOps.MUL) + c1 = UOp(Ops.MUL, dtypes.int, (y1, y2)) + c2 = UOp(Ops.MUL, dtypes.int, (y2, y2)) + c3 = UOp(Ops.MUL, dtypes.int, (y3, y2)) + c4 = UOp(Ops.MUL, dtypes.int, (y2, y1)) + c5 = UOp(Ops.MUL, dtypes.int, (y2, y3)) self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) self.assertEqual(matcher.rewrite(c3), c3) @@ -135,7 +135,7 @@ class TestPatternMatcher(unittest.TestCase): matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) - c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) + c3 = UOp(Ops.ADD, dtypes.float, (c1,c2)) self.assertEqual(matcher.rewrite(c3), c3) self.assertEqual(matcher.rewrite(c2), None) # that CONST/ALU -> ALU/CONST rewrite is now instant @@ -152,10 +152,10 @@ class TestPatternMatcher(unittest.TestCase): matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=[UPat(Ops.CONST), UPat(GroupOp.ALU)]), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) - c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) - c4 = UOp(Ops.ALU, dtypes.float, (c3,c2), BinaryOps.ADD) - c5 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD) - c6 = UOp(Ops.ALU, dtypes.float, (c3,c4), BinaryOps.ADD) + c3 = UOp(Ops.ADD, dtypes.float, (c1,c2)) + c4 = UOp(Ops.ADD, dtypes.float, (c3,c2)) + c5 = UOp(Ops.ADD, dtypes.float, (c2,c3)) + c6 = UOp(Ops.ADD, dtypes.float, (c3,c4)) self.assertEqual(matcher.rewrite(c3), None) self.assertEqual(matcher.rewrite(c4), c4) self.assertEqual(matcher.rewrite(c5), c5) @@ -165,8 +165,8 @@ class TestPatternMatcher(unittest.TestCase): matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)]) c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) - c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) - c4 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD) + c3 = UOp(Ops.ADD, dtypes.float, (c1,c2)) + c4 = UOp(Ops.ADD, dtypes.float, (c2,c3)) self.assertEqual(matcher.rewrite(c3), c3) self.assertEqual(matcher.rewrite(c4), None) @@ -175,9 +175,9 @@ class TestPatternMatcher(unittest.TestCase): c1 = UOp(Ops.CONST, dtypes.float, arg=1.0) c2 = UOp(Ops.CONST, dtypes.float, arg=2.0) c3 = UOp(Ops.CONST, dtypes.float, arg=3.0) - c4 = UOp(Ops.ALU, dtypes.float, (c1,), UnaryOps.EXP2) - c5 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) - c6 = UOp(Ops.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC) + c4 = UOp(Ops.EXP2, dtypes.float, (c1,)) + c5 = UOp(Ops.ADD, dtypes.float, (c1,c2)) + c6 = UOp(Ops.MULACC, dtypes.float, (c1,c2,c3)) self.assertEqual(matcher.rewrite(c4), None) self.assertEqual(matcher.rewrite(c5), None) self.assertEqual(matcher.rewrite(c6), c6) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7a3b1d568d..a9ebb44fdb 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -230,8 +230,6 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary() def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): - # TODO: remove this - if op is Ops.ALU: op, arg = arg, None if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg) return ret