mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
refactor LOAD(DEFINE_GLOBAL, VIEW) in kernels to LOAD(VIEW(DEFINE_GLOBAL)) (#10541)
* changes to core tinygrad * fixups pt1 TC=3 docs/abstractions2.py IMAGE=2 test_quantize_dsp test_schedule * more tests * green now * images stay images
This commit is contained in:
@@ -51,11 +51,11 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
|
||||
# describe the computation
|
||||
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
|
||||
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
|
||||
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
|
||||
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
|
||||
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.view(ShapeTracker.from_shape((1,))),))
|
||||
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.view(ShapeTracker.from_shape((1,))),))
|
||||
alu = ld_1 + ld_2
|
||||
output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
|
||||
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.view(ShapeTracker.from_shape((1,))), alu))
|
||||
s = UOp(Ops.SINK, dtypes.void, (st_0,))
|
||||
|
||||
# convert the computation to a "linearized" format (print the format)
|
||||
|
||||
Binary file not shown.
@@ -90,10 +90,10 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_multioutput(self):
|
||||
dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
|
||||
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
|
||||
a = UOp(Ops.LOAD, dtype, (g2, st.to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (g3, st.to_uop()))
|
||||
out0 = UOp(Ops.STORE, dtypes.void, (g0, st.to_uop(), a + b))
|
||||
out1 = UOp(Ops.STORE, dtypes.void, (g1, st.to_uop(), a * b))
|
||||
a = UOp(Ops.LOAD, dtype, src=(g2.view(st),))
|
||||
b = UOp(Ops.LOAD, dtype, src=(g3.view(st),))
|
||||
out0 = UOp(Ops.STORE, dtypes.void, src=(g0.view(st), a + b))
|
||||
out1 = UOp(Ops.STORE, dtypes.void, src=(g1.view(st), a * b))
|
||||
sink = UOp(Ops.SINK, src=(out0, out1))
|
||||
|
||||
a_t = Tensor.full(st.shape, 2).contiguous().realize()
|
||||
@@ -142,12 +142,12 @@ class TestLinearizer(unittest.TestCase):
|
||||
x = Tensor.randn(32, dtype=dtypes.float).realize()
|
||||
st_x = x.lazydata.st
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((1, 32)).expand((32, 32))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((32, 1))),))
|
||||
diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (0,)))
|
||||
store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
|
||||
store = UOp(Ops.STORE, dtypes.void, (g0.view(ShapeTracker.from_shape((1, 1))), second_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
opts = [
|
||||
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
|
||||
@@ -174,12 +174,12 @@ class TestLinearizer(unittest.TestCase):
|
||||
x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||
st_x = x.lazydata.st
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(st_x.reshape((27, 32, 1, 5))),))
|
||||
diff = second_x + first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
opts = [
|
||||
# locals
|
||||
@@ -232,15 +232,15 @@ class TestLinearizer(unittest.TestCase):
|
||||
x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5))),))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,)))
|
||||
third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
|
||||
third_x = UOp(Ops.LOAD, dtypes.float, (g3.view(x2.lazydata.st.reshape((27, 32, 1, 1, 5))),))
|
||||
mul = (third_x*second_reduce)
|
||||
third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 1, 5))), third_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
|
||||
lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
|
||||
@@ -302,12 +302,12 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
opts = [
|
||||
[Opt(OptOps.GROUPTOP, 0, 3)], # grouping
|
||||
@@ -329,14 +329,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
|
||||
x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
|
||||
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
|
||||
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
|
||||
first_x_p = UOp(Ops.LOAD, dtypes.float, (g2.view(x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1.view(x.lazydata.st.reshape((4, 32, 1))),))
|
||||
diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((4, 1, 1))), second_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
opts = [
|
||||
# [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
|
||||
@@ -361,14 +361,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
||||
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
|
||||
second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5))
|
||||
store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out))
|
||||
store1 = UOp(Ops.STORE, src=(g1.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_out))
|
||||
sink = UOp(Ops.SINK, src=(store0, store1))
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
|
||||
@@ -383,13 +383,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
||||
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g2.view(x.lazydata.st.reshape((27, 15, 1, 5))),))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
|
||||
store0 = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
|
||||
store1 = UOp(Ops.STORE, src=(g1.view(ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),))), first_reduce)) # noqa: E501
|
||||
wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
wanna_output1 = x.numpy().sum(axis=1).reshape(27,1,1,5)
|
||||
|
||||
@@ -402,12 +402,12 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
@@ -418,12 +418,12 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((27, 3, 1, 5))),))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
|
||||
second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((27, 1, 1, 5))), second_reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
opts = [[Opt(OptOps.UPCAST, 0, 3)]]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
@@ -453,15 +453,15 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 35, 1))),))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
|
||||
std = variance.alu(Ops.SQRT)
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((15, 25, 1, 1))), std))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
|
||||
@@ -471,15 +471,15 @@ class TestLinearizer(unittest.TestCase):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((15, 25, 1, 35))),))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
|
||||
std = variance.alu(Ops.SQRT)
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((15, 1, 1, 35))), std))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
|
||||
@@ -514,16 +514,16 @@ class TestLinearizer(unittest.TestCase):
|
||||
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))),))
|
||||
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
|
||||
# store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
|
||||
# verify_lazyop(store)
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((3, 27, 32, 1))),))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((3, 27, 1, 1))), variance))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
|
||||
@@ -535,15 +535,15 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_softmax_multireduce(self):
|
||||
x = Tensor.rand(4, 32).realize()
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
|
||||
first_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32))),))
|
||||
max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,)))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
|
||||
second_x = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((4, 32, 1,))),))
|
||||
centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
|
||||
exp_x = centered_x.alu(Ops.EXP2)
|
||||
sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,)))
|
||||
# y = exp_x * sum_exp_x.alu(Ops.RECIP) # kernels cannot do a return to full shape
|
||||
recip_sum_exp_x = sum_exp_x.alu(Ops.RECIP)
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((4,1,1))), recip_sum_exp_x))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[expected])
|
||||
@@ -564,62 +564,82 @@ class TestLinearizer(unittest.TestCase):
|
||||
real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(20), arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(-1), arg=0, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
|
||||
UOp(Ops.CONST, dtypes.int, arg=10, src=(
|
||||
x6:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
ast_const(dtypes.int, -1, (1, 20, 1)),
|
||||
x8:=UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x6,)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(200), arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(-1), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
|
||||
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(20), arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(-1), arg=2, src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
|
||||
x21:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), # noqa: E501
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=(
|
||||
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
|
||||
ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
|
||||
ast_const(dtypes.int, -1, (1, 20, 1)),)),)),))
|
||||
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
||||
UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x28:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 10), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(Ops.CONST, dtypes.int, arg=0, src=(
|
||||
x28,)),)),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=10, src=(
|
||||
x21,)),)),)),)),)),)),
|
||||
x8,)),)),))
|
||||
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
|
||||
|
||||
def test_argmax_multireduce_flat(self):
|
||||
t = Tensor.randn(10, 20).realize()
|
||||
t_max = t.max().realize()
|
||||
real_argmax = np.argmax(t.numpy())
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(1), arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(-1), arg=0, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
ast_const(dtypes.int, 200, (1, 1)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=200, src=(
|
||||
x6:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
ast_const(dtypes.int, -1, (1, 1)),
|
||||
x8:=UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x6,)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=(
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(200), arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(-1), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
|
||||
ast_const(dtypes.bool, True, (200, 1)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(1), arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(-1), arg=2, src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
|
||||
x21:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), # noqa: E501
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
|
||||
ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
|
||||
ast_const(dtypes.int, -1, (1, 1)),)),)),))
|
||||
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
||||
UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x28:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 200), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(Ops.CONST, dtypes.int, arg=0, src=(
|
||||
x28,)),)),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=200, src=(
|
||||
x21,)),)),)),)),)),)),
|
||||
x8,)),)),))
|
||||
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@@ -636,19 +656,19 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),))
|
||||
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,)))
|
||||
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
|
||||
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),))
|
||||
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,)))
|
||||
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
|
||||
|
||||
@@ -663,19 +683,19 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((1, N, N)).expand((N,N,N))),))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N))),))
|
||||
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,)))
|
||||
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,1,N))), r1))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
|
||||
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
|
||||
x_ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))),))
|
||||
x_ld1 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(x.lazydata.st.reshape((N, N, 1))),))
|
||||
r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,)))
|
||||
r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((N,1,1))), r1))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
|
||||
|
||||
@@ -690,36 +710,34 @@ class TestLinearizer(unittest.TestCase):
|
||||
b = Tensor.rand(1, 1).realize()
|
||||
opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]
|
||||
|
||||
# TODO: these large ASTs are suboptimal but we need this until the scheduler can fuse these
|
||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501
|
||||
ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))
|
||||
ld1 = x.lazydata.st.reshape((N, N, 1))
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
ld1.to_uop(),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ld1, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),)),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
ld0.to_uop(),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ld0, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),)),)),)),)),)),)),
|
||||
ast_const(dtypes.float, 0.0, (N, 1, 1)),
|
||||
ast_const(dtypes.float, 1.0, (N, 1, 1)),)),)),))
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
@@ -729,34 +747,32 @@ class TestLinearizer(unittest.TestCase):
|
||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
ast_const(dtypes.float, 0.5*N, (1, 1, N)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
ld1.to_uop(),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ld1, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),)),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
ld0.to_uop(),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ld0, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
|
||||
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),)),)),)),)),)),)),
|
||||
ast_const(dtypes.float, 0.0, (1, 1, N)),
|
||||
ast_const(dtypes.float, 1.0, (1, 1, N)),)),)),))
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
|
||||
# pad reduce axis
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])
|
||||
|
||||
@@ -765,29 +781,29 @@ class TestLinearizer(unittest.TestCase):
|
||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),)),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
|
||||
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(), arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),)),)),)),)),)),)),
|
||||
ast_const(dtypes.float, 0.0, (1, 1, 1, 1)),
|
||||
ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),))
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output])
|
||||
@@ -796,9 +812,9 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_end_local(self):
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)]
|
||||
load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
|
||||
load = UOp(Ops.LOAD, dtypes.int, (g1.view(ShapeTracker.from_shape((32,))),))
|
||||
reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,)))
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker.from_shape((1,))), reduce))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize()
|
||||
k = helper_linearizer_ast(sink, [load_t], wanna_output=[load_t.numpy().sum()])[1]
|
||||
@@ -903,11 +919,11 @@ class TestLinearizer(unittest.TestCase):
|
||||
g0, g1 = [UOp(Ops.DEFINE_GLOBAL, DT.ptr(), arg=i) for i in range(2)]
|
||||
|
||||
# data1[0] + VAL
|
||||
a = UOp(Ops.LOAD, DT, (g1, ST)) + VAL
|
||||
a = UOp(Ops.LOAD, DT, (g1.view(ST.arg),)) + VAL
|
||||
# (literal const 1) + VAL
|
||||
b = ast_const(DT, 1, ST.arg.shape) + VAL
|
||||
|
||||
store = UOp(Ops.STORE, src=(g0, ST, (a+b)))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ST.arg), (a+b)))
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
lin = Kernel(sink)
|
||||
lin.linearize()
|
||||
@@ -1387,13 +1403,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts(self):
|
||||
Tensor.manual_seed(0)
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(9600), arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=()),)),)),)),))
|
||||
opt = [
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
|
||||
Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)
|
||||
@@ -1406,13 +1422,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts_with_gep(self):
|
||||
Tensor.manual_seed(0)
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=()),)),)),)),))
|
||||
opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
|
||||
@@ -1619,19 +1635,19 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
def test_half4_load_unrolled(self):
|
||||
# from llama 7B shard 4 gpus
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.half, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),))
|
||||
|
||||
# TODO: fix this, expected might change but should be positive
|
||||
for expected, opts in [
|
||||
@@ -1648,22 +1664,22 @@ class TestFloat4(unittest.TestCase):
|
||||
@unittest.skip("this doesn't happen anymore")
|
||||
def test_float4_acc(self):
|
||||
# from float32 stable diffusion red tinybox
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(33554432), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(67108864), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(67108864), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(294912), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2, src=()),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(128), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), arg=3, src=()),)),)),)),)),))
|
||||
|
||||
for expected, opts in [
|
||||
(1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]),
|
||||
@@ -1678,16 +1694,16 @@ class TestFloat4(unittest.TestCase):
|
||||
@unittest.skip("this doesn't happen anymore")
|
||||
def test_float2_acc(self):
|
||||
# from resnet
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||
UOp(Ops.CAST, dtypes.half, src=(
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(212926464), arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(212926464), arg=0, src=()),)),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(462422016), arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(462422016), arg=1, src=()),)),)),)),)),)),)),))
|
||||
for expected, opts in [
|
||||
(16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501
|
||||
(4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
|
||||
@@ -1779,7 +1795,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
|
||||
assert isinstance(ast, UOp), "ast must be UOp"
|
||||
inbufs = [x.lazydata.base.buffer for x in inputs]
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[2].dtype).allocate() \
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() \
|
||||
for out in ast.src]
|
||||
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
|
||||
|
||||
@@ -1800,7 +1816,7 @@ def reset_bufs(bufs:list[Buffer]):
|
||||
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
|
||||
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> list[Kernel]:
|
||||
lins: list[Kernel] = []
|
||||
outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]
|
||||
outbufs = [real_bufs[x.src[0].base.arg] for x in realized_ast.src]
|
||||
device = real_bufs[0].device
|
||||
|
||||
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=device))
|
||||
@@ -1967,23 +1983,23 @@ class TestKernelOpts(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_buf_index_not_found_tensor_core(self):
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.int, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.int, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(256), arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(256), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(1243), arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1243), arg=2, src=()),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(1243), arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1243), arg=3, src=()),)),)),)),)),)),))
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
with self.assertRaises(KernelOptError):
|
||||
k.apply_opt(Opt(OptOps.TC, 0, (-1, 1, 1)))
|
||||
@@ -2161,9 +2177,9 @@ class TestKernelOpts(unittest.TestCase):
|
||||
def test_padto_group(self):
|
||||
Tensor.manual_seed(0)
|
||||
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
||||
ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
||||
store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501
|
||||
ld0 = UOp(Ops.LOAD, dtypes.float, src=(g1.view(ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),))),)) # noqa: E501
|
||||
ld1 = UOp(Ops.LOAD, dtypes.float, src=(g2.view(ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)) # noqa: E501
|
||||
store = UOp(Ops.STORE, src=(g0.view(ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),))), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
|
||||
@@ -16,8 +16,8 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_unmerged_ifs(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(1605632), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(1605632), arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(1605632), arg=0, src=()),)),
|
||||
UOp(Ops.MAX, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
@@ -25,11 +25,11 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(1605632), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(1605632), arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(1605632), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(2359296), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(2359296), arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(2359296), arg=2, src=()),)),)),)),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.half, arg=0.9999950000374996, src=(
|
||||
x16:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.CONST, dtypes.half, arg=0.0, src=(
|
||||
@@ -48,31 +48,31 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_max_simplify_and_cancel(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=()),
|
||||
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(1000), arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=()),)),
|
||||
UOp(Ops.MUL, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1000), arg=1, src=()),
|
||||
x2,)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(1000), arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1000), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=2, src=()),
|
||||
x11:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(1), arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=2, src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
|
||||
x11,)),)),)),
|
||||
x14:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.ADD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x19:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
x21:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1000), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=0, src=(
|
||||
x19,)),)),)),
|
||||
x21,)),)),)),
|
||||
UOp(Ops.CONST, dtypes.int, arg=1000, src=(
|
||||
x11,)),)),)),)),))
|
||||
x14,)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
k.apply_opts(opts)
|
||||
@@ -84,12 +84,12 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_expander_new_srcs(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(25), arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(25), arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(25), arg=1, src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
k.apply_opts(opts)
|
||||
@@ -105,8 +105,8 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_llama_embedding(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(4096), arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=()),)),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
@@ -126,13 +126,13 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.CONST, dtypes.int, arg=-1, src=(
|
||||
x19:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), arg=1, src=()),
|
||||
x19,)),)),
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(1), arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), arg=1, src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
|
||||
x19,)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(131072000), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(131072000), arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(131072000), arg=2, src=()),)),)),)),)),)),)),)),))
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
@@ -142,25 +142,25 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_unaligns_idxs(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(3), arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(3), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.VIEW, dtypes.long.ptr(3), arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(3), arg=1, src=()),)),)),
|
||||
UOp(Ops.CAST, dtypes.long, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(5), arg=2, src=()),
|
||||
x14:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(5), arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(5), arg=2, src=()),)),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5), arg=3, src=()),
|
||||
x14,)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(5), arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(5), arg=3, src=()),)),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)]
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
k.apply_opts(opts)
|
||||
@@ -174,15 +174,15 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_unrolled_float4_align(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(1), arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(18), arg=1, src=()),
|
||||
x9:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.VIEW, dtypes.long.ptr(18), arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(18), arg=1, src=()),)),)),
|
||||
UOp(Ops.CONST, dtypes.long, arg=-1, src=(
|
||||
x11:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
|
||||
@@ -190,8 +190,8 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
|
||||
x11,)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18), arg=2, src=()),
|
||||
x9,)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(18), arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(18), arg=2, src=()),)),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)]
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
k.apply_opts(opts)
|
||||
@@ -206,16 +206,16 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
def test_upcasted_stores_out_of_order(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9360), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(9360), arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9360), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(144), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(144), arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(144), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1040), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(1040), arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1040), arg=2, src=()),)),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
k.apply_opts(opts)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -173,8 +173,8 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1.view(in_st_1.arg),)) * UOp(Ops.LOAD, dtypes.float, (g2.view(in_st_2.arg),))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ot_st.arg), UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
|
||||
ast = UOp(Ops.SINK, src=(store,))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
|
||||
_test_overflow(ast, opts)
|
||||
@@ -185,8 +185,8 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1.view(in_st_1.arg),)) * UOp(Ops.LOAD, dtypes.float, (g2.view(in_st_2.arg),))
|
||||
store = UOp(Ops.STORE, src=(g0.view(ot_st.arg), UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5)))))
|
||||
ast = UOp(Ops.SINK, src=(store,))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=16), Opt(op=OptOps.UPCAST, axis=4, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=5, arg=2)]
|
||||
_test_overflow(ast, opts)
|
||||
|
||||
@@ -243,8 +243,8 @@ class TestDSPCache(unittest.TestCase):
|
||||
# string becuase this breaks Python language server for syntax highlight for some reason
|
||||
ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(25088), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 896, 32, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.uchar.ptr(25088), arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 896, 32, 1, 0), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(25088), arg=0, src=()),)),
|
||||
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.XOR, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.MAX, dtypes.int, arg=None, src=(
|
||||
@@ -261,23 +261,23 @@ class TestDSPCache(unittest.TestCase):
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(150528), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 5376, 192, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.uchar.ptr(150528), arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 5376, 192, 0, 1), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(150528), arg=1, src=()),)),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.012368360534310341, src=(
|
||||
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.char, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.char.ptr(6144), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 48, 4), strides=(4, 128, 1), offset=0, mask=None, contiguous=False), View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 192, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.char.ptr(6144), arg=ShapeTracker(views=(View(shape=(32, 48, 4), strides=(4, 128, 1), offset=0, mask=None, contiguous=False), View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 192, 1), offset=0, mask=None, contiguous=False))), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.char.ptr(6144), arg=2, src=()),)),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.007441135589033365, src=(
|
||||
x22,)),)),)),)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(32), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.int.ptr(32), arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(32), arg=3, src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=9.203465015161783e-05, src=(
|
||||
x36:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=33.812857328652136, src=(
|
||||
|
||||
@@ -1994,7 +1994,8 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
|
||||
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER])
|
||||
def swizzle_cnt(u:UOp) -> int:
|
||||
return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL}])
|
||||
|
||||
class TestSwizzle(unittest.TestCase):
|
||||
def test_swizzle_simple(self):
|
||||
@@ -2096,7 +2097,7 @@ class TestSwizzle(unittest.TestCase):
|
||||
run_schedule(check_schedule(t, 3))
|
||||
np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]])
|
||||
|
||||
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
|
||||
def store_val(si:ScheduleItem): return si.ast.src[0].src[1]
|
||||
zero_pm = UPat(Ops.CONST, arg=0)
|
||||
class TestView(unittest.TestCase):
|
||||
def test_all_masked_out(self):
|
||||
@@ -2252,7 +2253,7 @@ class TestConst(unittest.TestCase):
|
||||
a = Tensor.ones((4,)).pad((1, 1)).contiguous()
|
||||
sched = a.schedule()
|
||||
print(sched[0].ast)
|
||||
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat.where(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0))),))
|
||||
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat.where(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0))),))
|
||||
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
|
||||
run_schedule(sched)
|
||||
self.assertListEqual(a.tolist(), [0, 1, 1, 1, 1, 0])
|
||||
@@ -2261,7 +2262,7 @@ class TestConst(unittest.TestCase):
|
||||
a = Tensor.ones((4,)).contiguous()
|
||||
sched = a.schedule()
|
||||
print(sched[0].ast)
|
||||
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.CONST)),))
|
||||
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(Ops.CONST)),))
|
||||
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
|
||||
run_schedule(sched)
|
||||
self.assertListEqual(a.tolist(), [1, 1, 1, 1])
|
||||
|
||||
@@ -127,8 +127,8 @@ class TestBEAM(unittest.TestCase):
|
||||
# taken from https://github.com/tinygrad/tinygrad/issues/4612
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(256), arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
@@ -137,23 +137,23 @@ class TestBEAM(unittest.TestCase):
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=2, src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=3, src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=4, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=4, src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=5, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=5, src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=6, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(64128), arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64128), arg=6, src=()),)),)),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=1.4285714285714286, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
|
||||
lin = Kernel(ast)
|
||||
|
||||
@@ -92,7 +92,7 @@ class TestTensorUOp(unittest.TestCase):
|
||||
out.realize()
|
||||
self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist())
|
||||
|
||||
reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(Ops.REDUCE_AXIS)))))
|
||||
reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, src=(UPat(), UPat(Ops.REDUCE_AXIS)))))
|
||||
class TestReduceOp(unittest.TestCase):
|
||||
def test_no_split_reduce_kernel(self):
|
||||
a = Tensor.rand(4, 4).realize()
|
||||
|
||||
@@ -29,46 +29,46 @@ class TestVerifyAST(unittest.TestCase):
|
||||
buf_0 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
|
||||
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
|
||||
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 2)
|
||||
a = UOp(Ops.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
store = UOp(Ops.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
|
||||
a = UOp(Ops.LOAD, dtype, (buf_1.view(ShapeTracker.from_shape((32, 1))),))
|
||||
b = UOp(Ops.LOAD, dtype, (buf_2.view(ShapeTracker.from_shape((32, 1))),))
|
||||
store = UOp(Ops.STORE, dtypes.void, (buf_0.view(ShapeTracker.from_shape((32, 1))), a+b))
|
||||
helper_test_verify_ast(store)
|
||||
|
||||
def test_exactly_one_full_shape(self):
|
||||
dtype = dtypes.int
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[2].view(ShapeTracker.from_shape((32, 1))),))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[3].view(ShapeTracker.from_shape((32, 1))),))
|
||||
st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[4].view(ShapeTracker.from_shape((32, 32))),))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[5].view(ShapeTracker.from_shape((32, 32))),))
|
||||
st1 = UOp.store(bufs[1], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
|
||||
|
||||
def test_no_implicit_broadcasting(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((4, 32))),))
|
||||
b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.MAX, (1,)))
|
||||
st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
st = UOp(Ops.STORE, dtypes.void, (bufs[0].view(ShapeTracker.from_shape((4, 32))), b))
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
|
||||
def test_shrink_ok(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),))),))
|
||||
b = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),))),))
|
||||
st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 32))), a+b)
|
||||
helper_test_verify_ast(st)
|
||||
|
||||
def test_reduce_store(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((32, 1))),))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
|
||||
def test_reduce_add_store(self):
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((32, 1))),))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
@@ -83,7 +83,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
|
||||
def test_assert_swizzle(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtypes.float, (buf.view(ShapeTracker.from_shape((32, 1))),))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
|
||||
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "UOp verification failed"): helper_test_verify_ast(st)
|
||||
@@ -91,7 +91,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
def test_const_view_always_valid(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(())),))
|
||||
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a.cast(dtypes.float))
|
||||
st = UOp.store(buf.view(ShapeTracker.from_shape(())), a.cast(dtypes.float))
|
||||
helper_test_verify_ast(st)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -99,7 +99,7 @@ class Kernel:
|
||||
return ret
|
||||
|
||||
@property
|
||||
def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
|
||||
def membufs(self) -> list[UOp]: return dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
|
||||
|
||||
def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
|
||||
upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
|
||||
@@ -456,7 +456,7 @@ class Kernel:
|
||||
# NOTE: if CONST got masked after applying opts, we create a new VALID
|
||||
if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.view(st_uop.arg).valid()
|
||||
# otherwise we just replace the VIEW source
|
||||
return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
|
||||
return ret.replace(src=(ret.src[0].replace(arg=st_uop.arg),)+ret.src[1:])
|
||||
if op.op is Ops.SINK:
|
||||
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
|
||||
self.local_dims, self.upcasted, self.dont_use_locals))
|
||||
@@ -491,8 +491,8 @@ class Kernel:
|
||||
st = store_st = ShapeTracker.from_shape(local_shape)
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), f"temp{i}")
|
||||
if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
|
||||
local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
|
||||
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
|
||||
local_store = UOp.store(local_buffer.view(store_st), srcs[i])
|
||||
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer.view(st), local_store))
|
||||
|
||||
tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
|
||||
if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
|
||||
@@ -518,11 +518,11 @@ class Kernel:
|
||||
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
||||
local_size = st_uop.arg.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
|
||||
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
|
||||
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st_uop.arg), UOp.store(local_buffer.view(st_uop.arg), ret)))
|
||||
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
|
||||
return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
|
||||
return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st_uop.arg), UOp.store(local_buffer.view(st_uop.arg), grouped_reduce)))
|
||||
|
||||
return ret
|
||||
fixed_ast = fixup_ast(self.ast)
|
||||
@@ -578,7 +578,7 @@ class Kernel:
|
||||
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
mem_bytes = sum(max(x.src[0].dtype.nbytes() for x in group)
|
||||
for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].arg)))
|
||||
for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.src[0].base.op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].base.arg)))
|
||||
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
@@ -92,7 +92,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
||||
first_upcasted = len(full_shape)-ki.upcasted
|
||||
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
|
||||
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort() if x.op is Ops.REDUCE_AXIS))
|
||||
local_loads = [x for x in ast.toposort() if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
|
||||
local_loads = [x for x in ast.toposort() if x.op is Ops.LOAD and x.src[0].base.op is Ops.DEFINE_LOCAL]
|
||||
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
|
||||
group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
|
||||
global_dims = first_reduce-ki.local_dims
|
||||
@@ -141,19 +141,19 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
def lower_load_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
if x.op is Ops.LOAD:
|
||||
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if buf.op is Ops.DEFINE_LOCAL else ()
|
||||
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[1],)),) if buf.op is Ops.DEFINE_LOCAL else ()
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
if cast(PtrDType, buf.dtype).local and x.src[2].op is Ops.REDUCE:
|
||||
reduce_input = x.src[2].src[0]
|
||||
if cast(PtrDType, buf.dtype).local and x.src[1].op is Ops.REDUCE:
|
||||
reduce_input = x.src[1].src[0]
|
||||
store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
|
||||
else: store_back = False
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[1].src else u for u in ctx.idxs])
|
||||
if (not cast(PtrDType, buf.dtype).local) or store_back:
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
|
||||
return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[1]))
|
||||
|
||||
def lower_const(x:UOp):
|
||||
assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
|
||||
@@ -164,7 +164,7 @@ pm_lowerer = PatternMatcher([
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.var("buf"), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_load_store),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
||||
(UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]),
|
||||
])
|
||||
@@ -194,10 +194,10 @@ pm_quant = symbolic+PatternMatcher([
|
||||
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
|
||||
# mul 0 * c1 is 0
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
|
||||
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
|
||||
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
|
||||
# mul (with plus) 0 * c1 is 0
|
||||
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
|
||||
(UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
|
||||
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
|
||||
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
|
||||
lambda ld,v,c1: ld*c1),
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ merge_views = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
|
||||
# remove NOOP views
|
||||
(UPat.var("x").view(name="view"), lambda x,view: x if x.st is not None and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat().view(name="view"),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# VIEW on BufferOps replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="x"),), name="view"),
|
||||
@@ -362,11 +362,11 @@ view_right = merge_views+PatternMatcher([
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop())),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x),
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i), s.st.to_uop(), s) for i,x in enumerate(sink.src)])),
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||
# VALID
|
||||
@@ -387,10 +387,10 @@ fix_kernel_ops = PatternMatcher([
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
# no ImageDType after load
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# no ImageDType after index
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||
])
|
||||
|
||||
replace_globals = PatternMatcher([
|
||||
|
||||
@@ -92,7 +92,7 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
|
||||
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
|
||||
bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
|
||||
for x in lin.bufs:
|
||||
if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
|
||||
if x.src[0].base.op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].base.arg].append(x)
|
||||
# TODO: Nones are staying in here if buffers are optimized out!
|
||||
# TODO: add a test for this
|
||||
rawbufs: list[Optional[Buffer]] = [None]*(max(bufsts)+1)
|
||||
|
||||
@@ -175,7 +175,7 @@ class Profiling(contextlib.ContextDecorator):
|
||||
cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
|
||||
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
|
||||
|
||||
VERSION = 19
|
||||
VERSION = 20
|
||||
_db_connection = None
|
||||
def db_connection():
|
||||
global _db_connection
|
||||
|
||||
@@ -130,12 +130,12 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
||||
|
||||
# early LOAD has a <buf, shapetracker, store?>
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
|
||||
# early LOAD has a <bufview, store?>
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)), UPat(Ops.STORE))), lambda: True),
|
||||
|
||||
# early STORE has a <buf, shapetracker, val>
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
|
||||
# early STORE has a <bufview, val>
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)), UPat())), lambda: True),
|
||||
|
||||
# **** new style load/store ****
|
||||
|
||||
@@ -204,6 +204,7 @@ ast_spec = PatternMatcher([
|
||||
# shapes must have either 1 or n in each dimension
|
||||
(UPat(Ops.SINK, src=UPat(Ops.STORE), name="sink"), verify_sink_dims),
|
||||
# VIEW can only exist in the edges
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
|
||||
(UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
|
||||
# all parent UOps must have the same shape
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
|
||||
|
||||
Reference in New Issue
Block a user