update tests (#7554)

Co-authored-by: John Doe <null@mail.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Carl Basho
2024-11-05 16:35:15 +00:00
committed by GitHub
parent 207bca6cea
commit 630a7f37cf
12 changed files with 314 additions and 316 deletions

View File

@@ -2,7 +2,7 @@
import unittest
from tinygrad import Device
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.engine.search import Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -17,12 +17,12 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -35,14 +35,14 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x19:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
x20,)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.EXP2, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x5,
UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
x20,)),)),)),
@@ -67,19 +67,19 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),

View File

@@ -600,16 +600,16 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
ast_const(dtypes.int, -1, (1, 20, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
@@ -617,7 +617,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
@@ -632,16 +632,16 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
ast_const(dtypes.int, 200, (1, 1)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
ast_const(dtypes.int, -1, (1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
@@ -649,7 +649,7 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
ast_const(dtypes.bool, True, (200, 1)),)),)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
@@ -732,16 +732,16 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
ld1.to_uop(),)),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
@@ -765,16 +765,16 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (1, 1, N)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
ld1.to_uop(),)),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
@@ -801,16 +801,16 @@ class TestLinearizer(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
UOp(Ops.LOAD, dtypes.float, src=(
@@ -1604,7 +1604,7 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
@@ -1631,9 +1631,9 @@ class TestFloat4(unittest.TestCase):
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
@@ -1950,9 +1950,9 @@ class TestKernelOpts(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501

View File

@@ -5,7 +5,7 @@
import unittest
from test.helpers import ast_const
from tinygrad import Device, dtypes
from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.search import Opt, OptOps
@@ -18,12 +18,12 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MAX, dtypes.half, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -51,10 +51,10 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -63,7 +63,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.bool, True, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -107,11 +107,11 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -137,10 +137,10 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -169,9 +169,9 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -201,7 +201,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),

View File

@@ -2,7 +2,7 @@
import unittest, random
import numpy as np
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
@@ -44,8 +44,8 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -89,9 +89,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
ast_const(dtypes.float, 0.1464405059814453, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
ast_const(dtypes.float, 1.0, st_src=(
@@ -109,7 +109,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -137,13 +137,13 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.SQRT, dtypes.float, arg=None, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
x9:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -181,9 +181,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -200,13 +200,13 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -220,19 +220,19 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.float, 1.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
ast_const(dtypes.float, 1.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),
@@ -246,11 +246,11 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),
ast_const(dtypes.float, 1.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(Ops.SQRT, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(1,), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),
@@ -262,7 +262,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),
@@ -277,9 +277,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -299,10 +299,10 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
x6:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -323,9 +323,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(51864, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -344,9 +344,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -365,12 +365,12 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 0, 14, 1, 196, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -383,9 +383,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.SQRT, dtypes.float, arg=None, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -403,7 +403,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
@@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(0, 0, 1, 40, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -436,13 +436,13 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(1536, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(0, 0, 36, 9, 0, 0, -3, -1), offset=8, mask=None, contiguous=False),)), src=()),)),
@@ -478,7 +478,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(4, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -504,20 +504,20 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x4:=ast_const(dtypes.float, 0.000244140625, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -539,51 +539,51 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=9, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=13, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=15, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 17280, 180, 18, 1), offset=19, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=17, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
ast_const(dtypes.float, 2.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
x80:=UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
x80:=UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -622,7 +622,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -637,7 +637,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -692,17 +692,17 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.bfloat16.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.bfloat16, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.WHERE, dtypes.bfloat16, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x5:=UOp(Ops.CAST, dtypes.bfloat16, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),
x9:=ast_const(dtypes.bfloat16, 230.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=(
UOp(Ops.ADD, dtypes.bfloat16, arg=None, src=(
UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=(
UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=(
x5,
ast_const(dtypes.bfloat16, 0.004347826086956522, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),
@@ -710,10 +710,10 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),
ast_const(dtypes.bfloat16, 1.99375e-07, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.bfloat16, arg=None, src=(
UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=(
UOp(Ops.MUL, dtypes.bfloat16, arg=None, src=(
UOp(Ops.ADD, dtypes.bfloat16, arg=None, src=(
x5,
x9,)),
ast_const(dtypes.bfloat16, 0.0012987012987012987, st_src=(
@@ -732,7 +732,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 128, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 128), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -750,7 +750,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(3072, 0, 0, 32, 1, 1024, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -767,9 +767,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.EXP2, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
@@ -792,7 +792,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 256, 4, 16, 4, 16), strides=(0, 50176, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 256), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(1048576, 0, 0, 64, 1, 4096, 1088, 17), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -809,26 +809,26 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
x5,
x10:=ast_const(dtypes.float, 0.0, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
ast_const(dtypes.float, 0.06788442333021306, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),
x5,)),
ast_const(dtypes.float, -0.03394221166510653, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=-26040, mask=((26040, 32640),), contiguous=False),)), src=()),)),
@@ -849,9 +849,9 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(77, 0, 0, 7, 1, 0, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -873,7 +873,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
UOp(Ops.ALU, dtypes.uint, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.uint, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.uint, arg=None, src=(
ast_const(dtypes.uchar, 1, st_src=(
@@ -892,10 +892,10 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
@@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
@@ -940,10 +940,10 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
@@ -967,7 +967,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -988,7 +988,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 128, 4, 58, 4, 58), strides=(0, 401408, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 128), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(6889472, 0, 0, 464, 2, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -1052,18 +1052,18 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 3, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)),
@@ -1071,12 +1071,12 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
x21:=ast_const(dtypes.bool, True, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -1091,13 +1091,13 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -1112,7 +1112,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
@@ -1125,7 +1125,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
@@ -1143,7 +1143,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 0, 56, 1, 3136, 0, 0, 802816), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(10, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -1176,14 +1176,14 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.bool, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -1203,25 +1203,25 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.half, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
UOp(Ops.RECIP, dtypes.half, arg=None, src=(
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.WHERE, dtypes.half, arg=None, src=(
x6:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CONST, dtypes.half, arg=1.0, src=()),
x9:=UOp(Ops.CONST, dtypes.half, arg=0.0, src=()),)),
UOp(Ops.ALU, dtypes.half, arg=UnaryOps.EXP2, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
UOp(Ops.EXP2, dtypes.half, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.WHERE, dtypes.half, arg=None, src=(
x6,
UOp(Ops.CONST, dtypes.half, arg=2.0, src=()),
x9,)),
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
@@ -1231,7 +1231,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
UOp(Ops.WHERE, dtypes.half, arg=None, src=(
x6,
UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()),
x9,)),)),)),)),)),)),))
@@ -1250,7 +1250,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 256), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -1267,29 +1267,29 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.ALU, dtypes.uchar, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.uchar, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.CONST, dtypes.int, arg=1, src=()),
x20:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),)),
UOp(Ops.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
x22:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CONST, dtypes.int, arg=-1, src=()),
x20,)),)),)),
UOp(Ops.ALU, dtypes.bool, arg=TernaryOps.WHERE, src=(
UOp(Ops.WHERE, dtypes.bool, arg=None, src=(
x22,
UOp(Ops.CONST, dtypes.bool, arg=True, src=()),
UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),))
@@ -1307,7 +1307,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -1327,7 +1327,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, W, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, W), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False),

View File

@@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
# stuff needed to unpack a kernel
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -28,13 +28,13 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -46,12 +46,12 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.SQRT, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
x23:=ast_const(dtypes.float, 1.0, st_src=(
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.RECIP, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x23,
ast_const(dtypes.float, 1e-05, st_src=(
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
@@ -69,7 +69,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -86,7 +86,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -103,7 +103,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -120,7 +120,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -137,7 +137,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
@@ -154,7 +154,7 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),

View File

@@ -1689,10 +1689,10 @@ class TestIndexing(unittest.TestCase):
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
@@ -1776,11 +1776,11 @@ class TestIndexing(unittest.TestCase):
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),

View File

@@ -108,12 +108,12 @@ class TestBEAM(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501

View File

@@ -2,7 +2,7 @@ from typing import List
import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
@@ -24,7 +24,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
c1 = UOp.const(dtypes.int, 1)
c2 = UOp.const(dtypes.int, 2)
st = time.perf_counter()
uops = [UOp(Ops.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
uops = [UOp(Ops.ADD, dtypes.int, (c1, c2)) for _ in range(10000)]
et = time.perf_counter() - st
print(f"created {len(uops)} uops in {et*1000:.2f} ms")
@@ -35,9 +35,9 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
offset=0, mask=None, contiguous=False),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
@@ -194,7 +194,7 @@ class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
out = UOp(Ops.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
out = UOp(Ops.ADD, dtypes.float, (c1, c2))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
@@ -204,9 +204,9 @@ class TestUOpGraph(unittest.TestCase):
def test_where_same_fold(self):
v = UOp.variable('tmp', 0, 1)
c0 = UOp(Ops.CONST, dtypes.int, arg=0)
vc = UOp(Ops.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
out = UOp(Ops.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
@@ -217,7 +217,7 @@ class TestUOpGraph(unittest.TestCase):
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
out = UOp(Ops.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
out = UOp(Ops.WHERE, dtypes.float, (bf, c1, c2))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
@@ -240,7 +240,7 @@ class TestUOpGraph(unittest.TestCase):
ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(2), (ld,))
x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(Ops.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
alu = UOp(Ops.SQRT, dtypes.float, (x, ))
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is Ops.VECTORIZE]), 0)
@@ -375,12 +375,12 @@ class TestUOpGraph(unittest.TestCase):
v = UOp.variable("tmp", 0, 1)
c2 = UOp(Ops.CONST, dtypes.int, arg=2)
c4 = UOp(Ops.CONST, dtypes.int, arg=4)
vc = UOp(Ops.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(Ops.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
vc = UOp(Ops.ADD, dtypes.int, (v, c2))
out = UOp(Ops.ADD, dtypes.int, (vc, c4))
uops = to_uops_list([out])
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(out.op, BinaryOps.ADD)
self.assertEqual(out.op, Ops.ADD)
self.assertEqual(out.src[1].op, Ops.CONST)
self.assertEqual(out.src[1].arg, 6)
@@ -436,7 +436,7 @@ class TestUOpGraph(unittest.TestCase):
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False))
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False))
alu = UOp(Ops.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
uops = to_uops_list([store])
ranges = [x for x in uops if x.op is Ops.RANGE]

View File

@@ -33,7 +33,7 @@ def _test_single_value(vals, op, dts):
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts))
alu = uop(uops, Ops.ALU, output_dtype, loads, op)
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
@@ -48,7 +48,7 @@ def _test_single_value_const(vals, op, dts):
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, Ops.ALU, output_dtype, loads, op)
alu = uop(uops, op, output_dtype, loads)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg([out])
@@ -332,8 +332,8 @@ class TestAssembly(unittest.TestCase):
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHL)
@@ -344,8 +344,8 @@ class TestAssembly(unittest.TestCase):
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),))
a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
a1 = UOp(Ops.IDIV, dtypes.int, (l1, c1))
a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].op, BinaryOps.SHR)
@@ -357,8 +357,8 @@ class TestUOpMethod(unittest.TestCase):
a = UOp(Ops.CONST, dtypes.float, (), 2.0)
b = UOp(Ops.CONST, dtypes.float, (), 3.0)
add = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.ADD)
mul = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.MUL)
add = UOp(Ops.ADD, dtypes.float, (a, b))
mul = UOp(Ops.MUL, dtypes.float, (a, b))
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
def test_uop_variables(self):

View File

@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, Ops, UOp
from tinygrad.ops import flops_mem, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -125,8 +125,8 @@ class TestUOpsStats(unittest.TestCase):
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(Ops.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
uops = linearize_uop(u5.sink())
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
@@ -135,7 +135,7 @@ class TestUOpsStats(unittest.TestCase):
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
uops_fma = linearize_uop(u4.sink())
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

View File

@@ -54,7 +54,7 @@ class TestPatternMatcher(unittest.TestCase):
def test_uop(self):
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
c2 = UOp(Ops.ADD, dtypes.float, (c1, c1))
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
@@ -63,7 +63,7 @@ class TestPatternMatcher(unittest.TestCase):
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c4 = UOp(Ops.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
c4 = UOp(Ops.ADD, dtypes.float, (c3, c3))
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
self.assertEqual(matcher.rewrite(c4), None)
@@ -76,8 +76,8 @@ class TestPatternMatcher(unittest.TestCase):
])
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
c3 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
c4 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
c3 = UOp(Ops.MAX, dtypes.float, (c1, c1))
c4 = UOp(Ops.MUL, dtypes.float, (c1, c1))
c5 = UOp(Ops.CONST, dtypes.int, arg=-1)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
@@ -93,11 +93,11 @@ class TestPatternMatcher(unittest.TestCase):
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
y2 = UOp(Ops.CONST, dtypes.int, arg=2)
y3 = UOp(Ops.CONST, dtypes.int, arg=-1)
c1 = UOp(Ops.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
c2 = UOp(Ops.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
c3 = UOp(Ops.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
c4 = UOp(Ops.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
c5 = UOp(Ops.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
c1 = UOp(Ops.MUL, dtypes.int, (y1, y2))
c2 = UOp(Ops.MUL, dtypes.int, (y2, y2))
c3 = UOp(Ops.MUL, dtypes.int, (y3, y2))
c4 = UOp(Ops.MUL, dtypes.int, (y2, y1))
c5 = UOp(Ops.MUL, dtypes.int, (y2, y3))
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
self.assertEqual(matcher.rewrite(c3), c3)
@@ -135,7 +135,7 @@ class TestPatternMatcher(unittest.TestCase):
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c3 = UOp(Ops.ADD, dtypes.float, (c1,c2))
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
@@ -152,10 +152,10 @@ class TestPatternMatcher(unittest.TestCase):
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=[UPat(Ops.CONST), UPat(GroupOp.ALU)]), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c4 = UOp(Ops.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
c5 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
c6 = UOp(Ops.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
c3 = UOp(Ops.ADD, dtypes.float, (c1,c2))
c4 = UOp(Ops.ADD, dtypes.float, (c3,c2))
c5 = UOp(Ops.ADD, dtypes.float, (c2,c3))
c6 = UOp(Ops.ADD, dtypes.float, (c3,c4))
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), c5)
@@ -165,8 +165,8 @@ class TestPatternMatcher(unittest.TestCase):
matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c4 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
c3 = UOp(Ops.ADD, dtypes.float, (c1,c2))
c4 = UOp(Ops.ADD, dtypes.float, (c2,c3))
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c4), None)
@@ -175,9 +175,9 @@ class TestPatternMatcher(unittest.TestCase):
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
c4 = UOp(Ops.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
c5 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c6 = UOp(Ops.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
c4 = UOp(Ops.EXP2, dtypes.float, (c1,))
c5 = UOp(Ops.ADD, dtypes.float, (c1,c2))
c6 = UOp(Ops.MULACC, dtypes.float, (c1,c2,c3))
self.assertEqual(matcher.rewrite(c4), None)
self.assertEqual(matcher.rewrite(c5), None)
self.assertEqual(matcher.rewrite(c6), c6)