mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
migrate test_linearizer.py to UOp AST, pt. 2 (#6228)
This commit is contained in:
@@ -8,7 +8,7 @@ from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
|
||||
from tinygrad.codegen.lowerer import get_grouped_dims
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.device import Device, Buffer
|
||||
from extra.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps, to_uop
|
||||
from extra.ops import BinaryOps, BufferOps, TernaryOps, ReduceOps, UnaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
# from tinygrad.shape.symbolic import Variable
|
||||
@@ -283,12 +283,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
# check how it works with one reduce optimized and one unoptimized
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 15, 1, 5))))
|
||||
diff = (second_x-first_reduce)
|
||||
second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,))
|
||||
store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
|
||||
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
opts = [
|
||||
[Opt(OptOps.GROUPTOP, 0, 3)], # grouping
|
||||
[Opt(OptOps.GROUPTOP, 1, 3)],
|
||||
@@ -298,7 +300,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
[Opt(OptOps.UNROLL, 1, 3)],
|
||||
]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
lins = helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts)
|
||||
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
@@ -371,17 +373,18 @@ class TestLinearizer(unittest.TestCase):
|
||||
# if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!)
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((27, 15, 1, 5))))
|
||||
diff = (second_x-first_reduce)
|
||||
second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,))
|
||||
store0 = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
|
||||
store1 = LazyOp(BufferOps.STORE, (first_reduce,), MemBuffer(1, dtypes.float, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)))) # noqa: E501
|
||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
||||
first_x = UOp(BufferOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
|
||||
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
|
||||
store0 = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
store1 = UOp(UOps.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
|
||||
wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
wanna_output1 = x.numpy().sum(axis=1).reshape(27,1,1,5)
|
||||
|
||||
ast = LazyOp(MetaOps.KERNEL, (store0,store1))
|
||||
ast = UOp(UOps.SINK, src=(store0, store1))
|
||||
k = Kernel(ast)
|
||||
prg = CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
|
||||
inbufs = [x.lazydata.base.buffer]
|
||||
@@ -394,29 +397,33 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_complete_unroll_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 3, 1, 5))))
|
||||
diff = (second_x-first_reduce)
|
||||
second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,))
|
||||
store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
|
||||
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts)
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_upcast_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 3, 1, 5))))
|
||||
diff = (second_x-first_reduce)
|
||||
second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,))
|
||||
store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
|
||||
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
opts = [[Opt(OptOps.UPCAST, 0, 3)]]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts)
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -425,67 +432,75 @@ class TestLinearizer(unittest.TestCase):
|
||||
# make sure the if block of a grouped reduce can be closed early and the result loaded back in
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 12, 1, 5))))
|
||||
diff = (second_x-first_reduce)
|
||||
second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,))
|
||||
store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
|
||||
diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
|
||||
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (1,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts)
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_mean_std_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
|
||||
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
|
||||
squares = (second_x-mean)*(second_x-mean)
|
||||
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
|
||||
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
|
||||
std = LazyOp(UnaryOps.SQRT, (variance,), None)
|
||||
store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (3,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (ReduceOps.SUM, (2,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
|
||||
std = variance.alu(UnaryOps.SQRT)
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_mean_std_multireduce_mid_dim(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
|
||||
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 35)))) # noqa: E501
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35))))
|
||||
squares = (second_x-mean)*(second_x-mean)
|
||||
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,))
|
||||
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 1, 1, 35)))) # noqa: E501
|
||||
std = LazyOp(UnaryOps.SQRT, (variance,), None)
|
||||
store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 1, 1, 35))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (2,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (ReduceOps.SUM, (1,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
|
||||
std = variance.alu(UnaryOps.SQRT)
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.expectedFailure
|
||||
def test_mean_std_multireduce_multiout(self):
|
||||
# TODO: Same error as in test_multiout_intermediate_multireduce
|
||||
# TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch)
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
|
||||
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
|
||||
squares = (second_x-mean)*(second_x-mean)
|
||||
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
|
||||
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
|
||||
std = LazyOp(UnaryOps.SQRT, (variance,), None)
|
||||
store_mean = LazyOp(BufferOps.STORE, (mean,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((15,25,1,1))))
|
||||
store_std = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
|
||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (3,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (ReduceOps.SUM, (2,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
|
||||
std = variance.alu(UnaryOps.SQRT)
|
||||
store_mean = UOp(UOps.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean))
|
||||
store_std = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
|
||||
sink = UOp(UOps.SINK, src=(store_std, store_mean))
|
||||
wanna_output = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)]
|
||||
|
||||
lins = helper_linearizer_ast((store_std,store_mean), [x], wanna_output=wanna_output)
|
||||
lins = helper_linearizer_ast(sink, [x], wanna_output=wanna_output)
|
||||
for k in lins:
|
||||
assert len([u for u in k.uops if u.op is UOps.DEFINE_ACC]) == 2, "got more than two accs (implies the kernel didn't reuse the mean reduce)"
|
||||
|
||||
@@ -493,19 +508,21 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_var_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
|
||||
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1)))) # noqa: E501
|
||||
# store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (3,)))
|
||||
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
|
||||
# store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
|
||||
# verify_lazyop(store)
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
|
||||
squares = (second_x-mean)*(second_x-mean)
|
||||
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
|
||||
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1)))) # noqa: E501
|
||||
store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
|
||||
squares = (second_x+neg_mean)*(second_x+neg_mean)
|
||||
squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (ReduceOps.SUM, (2,)))
|
||||
variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
|
||||
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
|
||||
# tinygrad ref
|
||||
y_tiny = x.var(axis=2, correction=0).reshape(3,27,1,1)
|
||||
np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)
|
||||
@@ -513,17 +530,19 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_softmax_multireduce(self):
|
||||
x = Tensor.rand(4, 32).realize()
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32))))
|
||||
max_x = LazyOp(op=ReduceOps.MAX, src=(first_x,), arg=(2,))
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((4, 32, 1,))))
|
||||
centered_x = LazyOp(op=BinaryOps.ADD, src=(second_x, LazyOp(op=UnaryOps.NEG, src=(max_x,), arg=None)))
|
||||
exp_x = LazyOp(op=UnaryOps.EXP2, src=(centered_x,))
|
||||
sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,))
|
||||
# y = LazyOp(op=BinaryOps.MUL, src=(exp_x, LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,)))) # kernels cannot do a return to full shape
|
||||
recip_sum_exp_x = LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,))
|
||||
store = LazyOp(op=BufferOps.STORE, src=(recip_sum_exp_x,), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1,1))))
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
|
||||
max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.MAX, (2,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
|
||||
centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
|
||||
exp_x = centered_x.alu(UnaryOps.EXP2)
|
||||
sum_exp_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (exp_x,), (ReduceOps.SUM, (1,)))
|
||||
# y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape
|
||||
recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP)
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[expected])
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[expected])
|
||||
|
||||
# *** buildup to fused indexing
|
||||
@unittest.skipIf(CI, "very slow because of recomputing")
|
||||
@@ -537,84 +556,112 @@ class TestLinearizer(unittest.TestCase):
|
||||
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
|
||||
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
|
||||
arange_axis = (3,)
|
||||
arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis)
|
||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (UOp(UOps.CONST, dtypes.int, (arange_input_st.to_uop(),), 1),), (ReduceOps.SUM, arange_axis))
|
||||
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
|
||||
out = arange-LazyOp.const(1, dtypes.int, output_shape)
|
||||
store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(0, dtypes.int, st=ShapeTracker.from_shape(output_shape)))
|
||||
helper_linearizer_ast((store, ), [], wanna_output=[real_arange])
|
||||
out = arange+ast_const(dtypes.int, -1, output_shape)
|
||||
store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [], wanna_output=[real_arange])
|
||||
with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
|
||||
def test_indexing_multireduce(self):
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
g2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2)
|
||||
arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
|
||||
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
|
||||
# TODO: do this arange broadcast in the scheduler
|
||||
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
|
||||
arange_axis = (3,)
|
||||
arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis)
|
||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (UOp(UOps.CONST, dtypes.int, (arange_input_st.to_uop(),), 1),), (ReduceOps.SUM, arange_axis))
|
||||
arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
|
||||
arange = arange-LazyOp.const(1, dtypes.int, arange_out_shape)
|
||||
arange = arange+ast_const(dtypes.int, -1, arange_out_shape)
|
||||
# p2: the indexing
|
||||
dataset = Tensor.rand(16384, 256).realize()
|
||||
data1 = MemBuffer(1, dataset.dtype, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape))
|
||||
data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop())
|
||||
idxs = Tensor([0,3,5,6]).realize()
|
||||
data2 = MemBuffer(2, dtypes.int, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape))
|
||||
reduce_input = LazyOp(BufferOps.LOAD, (), data1)*LazyOp(UnaryOps.CAST, (arange.eq(LazyOp(BufferOps.LOAD, (), data2)),), dataset.dtype)
|
||||
out = LazyOp(ReduceOps.SUM, (reduce_input, ), (1,))
|
||||
output_shape = tuple(1 if i in out.arg else s for i,s in enumerate(arange_out_shape))
|
||||
store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(output_shape)))
|
||||
data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop())
|
||||
arange_eq = arange.alu(BinaryOps.CMPNE, UOp(UOps.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
|
||||
reduce_input = UOp(UOps.LOAD, dataset.dtype, data1)*UOp(UOps.CAST, dataset.dtype.scalar(), src=(arange_eq,))
|
||||
out_axis = (1,)
|
||||
out = UOp(UOps.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (ReduceOps.SUM, out_axis))
|
||||
output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
|
||||
helper_linearizer_ast((store, ), [dataset, idxs], wanna_output=[real_index])
|
||||
helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])
|
||||
|
||||
# AssertionError: repeated stores in uops
|
||||
def test_argmax_multireduce_axis0(self):
|
||||
t = Tensor.randn(10, 20).realize()
|
||||
t_max = t.max((0,)).realize()
|
||||
real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
|
||||
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
|
||||
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),))), src=( # noqa E501
|
||||
LazyOp(BinaryOps.ADD, arg=None, src=(
|
||||
LazyOp(BinaryOps.ADD, arg=None, src=(
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()), # noqa E501
|
||||
LazyOp(UnaryOps.NEG, arg=None, src=(
|
||||
LazyOp(ReduceOps.MAX, arg=(0,), src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(UnaryOps.CAST, arg=dtypes.int, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)), # noqa E501
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)), # noqa E501
|
||||
LazyOp(BinaryOps.ADD, arg=None, src=(
|
||||
LazyOp(ReduceOps.SUM, arg=(2,), src=(
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))), src=()),)), # noqa E501
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),)),)), # noqa E501
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa E501
|
||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.CONST, dtypes.int, arg=10, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa E501
|
||||
UOp(UOps.ALU, dtypes.int, arg=UnaryOps.NEG, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(ReduceOps.MAX, (0,)), src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
|
||||
UOp(UOps.CONST, dtypes.bool, arg=True, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), # noqa E501
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(ReduceOps.SUM, (2,)), src=(
|
||||
UOp(UOps.CONST, dtypes.int, arg=-1, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), # noqa E501
|
||||
UOp(UOps.CONST, dtypes.int, arg=10, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), # noqa E501
|
||||
UOp(UOps.CONST, dtypes.int, arg=-1, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) # noqa E501
|
||||
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
|
||||
|
||||
def test_argmax_multireduce_flat(self):
  """Check a hand-written multi-reduce argmax kernel over a flattened 10x20 tensor.

  The AST stores a single int computed as ``200 + (-MAX(eq * idx)) + (-1)``, where
  ``eq`` is 1 wherever the loaded value equals the broadcast maximum and ``idx`` is
  ``SUM(masked -1 const) + 200``. The result is compared against ``np.argmax``.
  """
  t = Tensor.randn(10, 20).realize()
  t_max = t.max().realize()
  real_argmax = np.argmax(t.numpy())
  # Only the UOp form of the AST is built; the former LazyOp(MetaOps.KERNEL, ...) tree was
  # a dead store that this assignment overwrote immediately.
  ast = UOp(UOps.SINK, None, arg=None, src=(
    UOp(UOps.STORE, None, arg=None, src=(
      # output: one int element (shape (1, 1)) in global buffer 0
      UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
      UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
      UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
        UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
          UOp(UOps.CONST, dtypes.int, arg=200, src=(
            UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
          UOp(UOps.ALU, dtypes.int, arg=UnaryOps.NEG, src=(
            # outer reduce: MAX over axis 0 of (is-max mask * per-element int term)
            UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(ReduceOps.MAX, (0,)), src=(
              UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
                # (buf1 != buf2) != True, cast to int: 1 where the two loads are equal
                UOp(UOps.CAST, dtypes.int, arg=None, src=(
                  UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
                    UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
                      # buffer 1 read as 200 contiguous elements
                      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
                        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
                        UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
                      # buffer 2 read with all-zero strides (a single value broadcast to 200 rows)
                      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
                        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
                        UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
                    UOp(UOps.CONST, dtypes.bool, arg=True, src=(
                      UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), # noqa: E501
                # inner reduce: SUM over axis 1 of a masked -1 constant, plus 200
                # NOTE(review): presumably this yields a per-row index term (arange-like) - confirm
                UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
                  UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(ReduceOps.SUM, (1,)), src=(
                    UOp(UOps.CONST, dtypes.int, arg=-1, src=(
                      UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
                  UOp(UOps.CONST, dtypes.int, arg=200, src=(
                    UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), # noqa: E501
        UOp(UOps.CONST, dtypes.int, arg=-1, src=(
          UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) # noqa: E501
  # NOTE(review): assumes the helper binds [t, t_max] to global buffers 1 and 2 - confirm against helper_linearizer_ast
  helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@@ -630,19 +677,22 @@ class TestLinearizer(unittest.TestCase):
|
||||
# [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)]
|
||||
]
|
||||
|
||||
x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, N, N)).expand((N,N,N))))
|
||||
x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N))))
|
||||
r0 = LazyOp(ReduceOps.SUM, (x_ld0,), (1,))
|
||||
r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (0,))
|
||||
store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1,1,N))))
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
||||
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (ReduceOps.SUM, (1,)))
|
||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(ReduceOps.SUM, (0,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
|
||||
|
||||
x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))))
|
||||
x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, N, 1))))
|
||||
r0 = LazyOp(ReduceOps.SUM, (x_ld0,), (2,))
|
||||
r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (1,))
|
||||
store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((N,1,1))))
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
|
||||
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
|
||||
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (ReduceOps.SUM, (2,)))
|
||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (ReduceOps.SUM, (1,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_padto_max_multireduce(self):
|
||||
@@ -654,19 +704,22 @@ class TestLinearizer(unittest.TestCase):
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),]
|
||||
]
|
||||
|
||||
x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, N, N)).expand((N,N,N))))
|
||||
x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N))))
|
||||
r0 = LazyOp(ReduceOps.MAX, (x_ld0,), (1,))
|
||||
r1 = LazyOp(ReduceOps.MAX, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (0,))
|
||||
store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1,1,N))))
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
||||
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
||||
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (ReduceOps.MAX, (1,)))
|
||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (ReduceOps.MAX, (0,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
|
||||
|
||||
x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))))
|
||||
x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, N, 1))))
|
||||
r0 = LazyOp(ReduceOps.MAX, (x_ld0,), (2,))
|
||||
r1 = LazyOp(ReduceOps.MAX, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (1,))
|
||||
store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((N,1,1))))
|
||||
helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
|
||||
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
|
||||
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
|
||||
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (ReduceOps.MAX, (2,)))
|
||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (ReduceOps.MAX, (1,)))
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_padto_where_multireduce(self):
|
||||
@@ -683,33 +736,125 @@ class TestLinearizer(unittest.TestCase):
|
||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501
|
||||
ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))
|
||||
ld1 = x.lazydata.st.reshape((N, N, 1))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(2,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),)),)),), arg=(1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((N,1,1)))) # noqa: E501
|
||||
helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.5*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (1,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
ld1.to_uop(),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.75*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (2,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
ld0.to_uop(),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
|
||||
ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N))
|
||||
ld1 = x.lazydata.st.reshape((N, 1, N))
|
||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(1,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),)),)),), arg=(0,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,N)))) # noqa: E501
|
||||
helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.5*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (0,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
||||
ld1.to_uop(),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.75*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (1,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
||||
ld0.to_uop(),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) # noqa: E501
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
|
||||
# # pad reduce axis
|
||||
# helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])
|
||||
# pad reduce axis
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])
|
||||
|
||||
# ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N))
|
||||
# ld1 = x.lazydata.st.reshape((N,N,1,1))
|
||||
# wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
|
||||
# ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(2,3,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),)),)),), arg=(0,1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))) # noqa: E501
|
||||
# helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output])
|
||||
ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N))
|
||||
ld1 = x.lazydata.st.reshape((N,N,1,1))
|
||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.5*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (0, 1)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.75*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (2, 3)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),)),)),)),)) # noqa: E501
|
||||
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_end_local(self):
  """Sum-reduce a 32-element int buffer into one element and check where the final ENDIF lands.

  The test asserts that the last uop of the linearized kernel is an ENDIF and that the
  last STORE uop appears before it.
  """
  # Build the kernel directly as a UOp graph. The earlier MemBuffer/LazyOp setup ran a
  # redundant first helper call and was then overwritten local by local, so it is dropped.
  g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=i) for i in range(2)]
  load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
  reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (ReduceOps.SUM, (0,)))
  store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
  sink = UOp(UOps.SINK, src=(store,))
  # input tensor filled with ones, shaped like the load's shapetracker
  load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize()
  # NOTE(review): element [1] of the helper's return is used as the kernel object - confirm against helper_linearizer_ast
  k = helper_linearizer_ast(sink, [load_t], wanna_output=[load_t.numpy().sum()])[1]
  self.assertEqual(k.uops[-1].op, UOps.ENDIF)
  # the last STORE must come before the trailing ENDIF
  self.assertLess(k.uops.index([x for x in k.uops if x.op is UOps.STORE][-1]), k.uops.index(k.uops[-1]))
|
||||
|
||||
@@ -803,16 +948,18 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
||||
def test_load_cache_const_bufs(self):
|
||||
# make sure const buffers are differentiated from local and mem buffers
|
||||
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)), dtypes.int
|
||||
VAL = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=2, dtype=DT, st=ST))
|
||||
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int
|
||||
VAL = UOp(UOps.CONST, DT, (ST,), 2)
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(DT), arg=i) for i in range(2)]
|
||||
|
||||
# data1[0] + VAL
|
||||
a = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=DT, st=ST)), VAL))
|
||||
a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL
|
||||
# (literal const 1) + VAL
|
||||
b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))
|
||||
b = UOp(UOps.CONST, DT, (ST,), 1) + VAL
|
||||
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
|
||||
lin = Kernel(ast)
|
||||
store = UOp(UOps.STORE, src=(g0, ST, (a+b)))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
lin = Kernel(sink)
|
||||
lin.linearize()
|
||||
|
||||
assert len(lin.uops) <= 7, "too many uops"
|
||||
@@ -1235,7 +1382,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts(self):
|
||||
Tensor.manual_seed(0)
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
||||
opt = [
|
||||
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
|
||||
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
|
||||
@@ -1248,7 +1401,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
def test_skip_unmatching_upcasts_with_gep(self):
|
||||
Tensor.manual_seed(0)
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
||||
opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
|
||||
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
|
||||
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
|
||||
@@ -1405,7 +1564,19 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
def test_half4_load_unrolled(self):
|
||||
# from llama 7B shard 4 gpus
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (3,)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.half, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.half, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
|
||||
|
||||
# TODO: fix this, expected might change but should be positive
|
||||
for expected, opts in [
|
||||
@@ -1421,7 +1592,23 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
def test_float4_acc(self):
|
||||
# from float32 stable diffusion red tinybox
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(5, 6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (5, 6, 7)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
|
||||
|
||||
for expected, opts in [
|
||||
(1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]),
|
||||
(4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]),
|
||||
@@ -1435,7 +1622,16 @@ class TestFloat4(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_float2_acc(self):
|
||||
# from resnet
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||
UOp(UOps.CAST, dtypes.half, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (4, 6)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, src=(
|
||||
UOp(UOps.LOAD, dtypes.half, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
|
||||
for expected, opts in [
|
||||
(16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
|
||||
(4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]),
|
||||
@@ -1522,10 +1718,9 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
assert k.local_dims == 1
|
||||
assert k.upcasted == 1
|
||||
|
||||
def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp, UOp], inputs:List[Tensor], *args, **kwargs):
|
||||
if not isinstance(ast, LazyOp) and not isinstance(ast, UOp): ast = LazyOp(MetaOps.KERNEL, ast)
|
||||
def helper_linearizer_ast(ast:UOp, inputs:List[Tensor], *args, **kwargs):
|
||||
assert isinstance(ast, UOp), "ast must be UOp"
|
||||
inbufs = [x.lazydata.base.buffer for x in inputs]
|
||||
ast = to_uop(ast) if isinstance(ast, LazyOp) else ast
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, cast(DType,out.src[2].dtype)).allocate() \
|
||||
for out in ast.src]
|
||||
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
|
||||
@@ -1706,7 +1901,23 @@ class TestKernelOpts(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_buf_index_not_found_tensor_core(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(ReduceOps.SUM, (0,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.CAST, dtypes.float, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.LOAD, dtypes.int, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.int, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||
with self.assertRaises(KernelOptError):
|
||||
k.apply_opt(Opt(OptOps.TC, 0, 1))
|
||||
@@ -1880,12 +2091,14 @@ class TestKernelOpts(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_padto_group(self):
|
||||
Tensor.manual_seed(0)
|
||||
ld0 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))) # noqa: E501
|
||||
ld1 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))) # noqa: E501
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(ld0, ld1)),), arg=(0, 2, 4, 6)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
||||
ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
||||
ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (ReduceOps.SUM, (0, 2, 4, 6)),))) # noqa: E501
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
helper_linearizer_ast((ast, ), [data1, data2], opts=[
|
||||
helper_linearizer_ast(sink, [data1, data2], opts=[
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
|
||||
|
||||
Reference in New Issue
Block a user