start multireduce lowerer work (var/std) (#5537)
* multireduce no-opts works
* passed test_var_multireduce
* cleanup
* double reduce
* extra check for range_group
* more checking for range_groups
* cleaning up debug prints
* cleanup diff
* linters
* revert kernel changes
* these are uops toposort

---------

Co-authored-by: timmy <timmy0x@proton.me>
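The tests added below hand-build "multireduce" ASTs: two chained ReduceOps lowered into a single kernel, here for sum, variance and std. As a rough orientation only (an illustrative numpy sketch, not part of this commit; the names x, mean, var are hypothetical), the variance tests encode this computation:

import numpy as np
# variance as two chained reduces: first reduce -> mean, second reduce -> mean of squared deviations
x = np.random.randn(3, 27, 32).astype(np.float32)
mean = x.sum(axis=-1, keepdims=True) / x.shape[-1]                 # first ReduceOps.SUM, scaled by a CONST 1/N
var = ((x - mean) ** 2).sum(axis=-1, keepdims=True) / x.shape[-1]  # second ReduceOps.SUM, scaled by a CONST 1/N
assert np.allclose(var.squeeze(-1), x.var(axis=-1, ddof=0), atol=1e-4)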
@@ -101,13 +101,39 @@ class TestLinearizer(unittest.TestCase):
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
@unittest.skip("TODO: fix uops toposort")
def test_sum_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(32, dtype=dtypes.float).realize()
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, 32)).expand((32, 32))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (1,))
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((32, 1))))
squares = (second_x-first_reduce)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (0,))
store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1, 1))))
wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1)
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
@unittest.skip("TODO: fix uops toposort")
def test_double_sum_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(2, 32, 4, 16, dtype=dtypes.float).realize()
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((2, 1, 32, 4, 1, 16)).expand((2, 32, 32, 4, 16, 16))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,5))
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((2, 32, 1, 4, 16, 1))))
squares = (second_x-first_reduce)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,4))
store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((2, 1, 1, 4, 1, 1))))
wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((2,1,1,4,1,1))
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
@unittest.skip("still broken")
@unittest.skip("TODO: fix uops toposort")
def test_var_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(3, 27, 32).realize()
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32): the expand is folded into the LOAD so the first reduce's result broadcasts (see the numpy sketch after this test)
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1)).expand((3, 27, 32, 32))))
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1)))) # noqa: E501
# store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
@@ -115,11 +141,12 @@ class TestLinearizer(unittest.TestCase):
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
wanna_output = x.numpy().var(axis=2, ddof=0)
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1)))) # noqa: E501
store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
# tinygrad ref
y_tiny = x.var(axis=2, correction=0)
y_tiny = x.var(axis=2, correction=0).reshape(3,27,1,1)
np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)
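A note on the reshape/expand trick used in the LOADs above (illustrative numpy only, not part of the diff; variable names are hypothetical): the first LOAD is viewed as (3, 27, 1, 32) and expanded to (3, 27, 32, 32), so reducing axis 3 leaves the row sum already broadcast along axis 2, where the un-expanded second LOAD lives.

import numpy as np
x = np.random.randn(3, 27, 32).astype(np.float32)
first_x = np.broadcast_to(x.reshape(3, 27, 1, 32), (3, 27, 32, 32))   # expanded LOAD
first_reduce = first_x.sum(axis=3, keepdims=True)                     # (3, 27, 32, 1): row sum, repeated along axis 2
second_x = x.reshape(3, 27, 32, 1)                                    # un-expanded LOAD
mean = first_reduce / 32                                              # the CONST 0.03125
var = ((second_x - mean) ** 2).sum(axis=2, keepdims=True) / 32        # second reduce over axis 2
assert np.allclose(var.reshape(3, 27), x.var(axis=2, ddof=0), atol=1e-4)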
# *** buildup to fused indexing
@@ -323,11 +350,21 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn(8,7).softmax().realize()
helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
@unittest.skip("AST has implicit movement ops")
@unittest.skip("TODO: fix uops toposort")
def test_multireduce_unroll(self):
# unrolled multireduceops will cause an issue where a reduceop following another reduceop will need to bring the "unroll" back:
# e.g. you unroll into four values, the four values are summed, then you need four operations on that sum for the next reduceop (see the sketch after this test)
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 12), strides=(12, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 12), strides=(12, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),), arg=None),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
Tensor.manual_seed(0)
x = Tensor.randn(3, 27, 12, dtype=dtypes.float).realize()
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 12)).expand((3, 27, 12, 12))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/12, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 12, 1)))) # noqa: E501
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 12, 1))))
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/12, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1)))) # noqa: E501
store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
opts = [
[Opt(op=OptOps.UNROLL, axis=0, amt=12)],
[Opt(op=OptOps.UNROLL, axis=0, amt=6)],
@@ -335,8 +372,7 @@ class TestLinearizer(unittest.TestCase):
[Opt(op=OptOps.UNROLL, axis=0, amt=3)],
[Opt(op=OptOps.UNROLL, axis=0, amt=2)],
]
x = Tensor.randn(2,12).softmax().realize()
helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
helper_linearizer_ast((store,), [x], opts=opts, wanna_output=[wanna_output])
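To make the "unroll" comment above concrete (an illustrative plain-numpy sketch, not part of the diff; names are hypothetical): after unrolling the first reduce by 4 you get four partial accumulators that must be summed into one value, and that single value then has to be fed back into all four unrolled lanes of the second reduce.

import numpy as np
x = np.random.randn(12).astype(np.float32)
# first reduce, unrolled by 4: four partial sums, then a horizontal sum
partials = [x[i::4].sum() for i in range(4)]
total = sum(partials)                                    # the single value the next reduce depends on
# second reduce, also unrolled by 4: every lane needs `total`, so it is used four times
second = sum((x[i::4] - total).sum() for i in range(4))
assert np.allclose(second, (x - x.sum()).sum(), atol=1e-4)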
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skip("AST has implicit movement ops")
@@ -357,26 +393,61 @@ class TestLinearizer(unittest.TestCase):
[((x.numpy() - x.numpy().mean(axis=2, keepdims=True))/x.numpy().std(axis=2, keepdims=True, ddof=0)).sum(axis=2).reshape(-1)])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skip("AST has implicit movement ops")
@unittest.skip("TODO: fix uops toposort")
def test_mean_std_multireduce(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
x = Tensor.randn(15, 25, 35).realize()
helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std()])
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
std = LazyOp(UnaryOps.SQRT, (variance,), None)
store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skip("AST has implicit movement ops")
@unittest.skip("TODO: fix uops toposort")
def test_mean_std_multireduce_mid_dim(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None)), arg=None),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 1, 35), strides=(35, 35, 1), offset=0, mask=None, contiguous=True),)))), # noqa: E501
x = Tensor.randn(15, 25, 35).realize()
helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(1).reshape(-1)])
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 35)))) # noqa: E501
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35))))
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,))
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 1, 1, 35)))) # noqa: E501
std = LazyOp(UnaryOps.SQRT, (variance,), None)
store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 1, 1, 35))))
wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skip("AST has implicit movement ops")
@unittest.expectedFailure
def test_mean_std_multireduce_multiout(self):
std = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
mean = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
x = Tensor.randn(15, 25, 35).realize()
helper_linearizer_ast((std,mean), [x], wanna_output=[x.numpy().std(), x.numpy().mean()])
Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
std = LazyOp(UnaryOps.SQRT, (variance,), None)
third_reduce = LazyOp(ReduceOps.SUM, (second_x,), (2,))
mean_out = third_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
store_mean = LazyOp(BufferOps.STORE, (mean_out,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((15,25,1,1))))
store_std = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
wanna_output = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)]
lins = helper_linearizer_ast((store_std,store_mean), [x], wanna_output=wanna_output)
for k in lins:
assert len([u for u in k.uops if u.op is UOps.DEFINE_ACC]) == 2, "got more than two accs (didn't reuse the mean reduce)"
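For context on the assertion above (illustrative numpy, not part of the diff; names are hypothetical): both outputs derive from the same first reduce, so a lowering that reuses it needs only two accumulators, one for the shared row sum and one for the sum of squared deviations.

import numpy as np
x = np.random.randn(15, 25, 35).astype(np.float32)
s = x.sum(axis=2, keepdims=True)                           # accumulator 1: row sum, shared by both outputs
mean = s / 35                                              # the mean output reuses it directly
var = ((x - mean) ** 2).sum(axis=2, keepdims=True) / 35    # accumulator 2: squared deviations
std = np.sqrt(var)                                         # the std output
assert np.allclose(mean.squeeze(), x.mean(axis=2), atol=1e-4)
assert np.allclose(std.squeeze(), x.std(axis=2, ddof=0), atol=1e-4)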
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skip("AST has implicit movement ops")