From 61ee02e93dc953a2cb317ddab59fb85b1d5acd8f Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Thu, 18 Jul 2024 04:43:46 +0800
Subject: [PATCH] start multireduce lowerer work (var/std) (#5537)

* multireduce no-opts works

* passed test_var_multireduce

* cleanup

* double reduce

* extra check for range_group

* more checking for range_groups

* cleaning up debug prints

* cleanup diff

* linters

* revert kernel changes

* these are uops toposort

---------

Co-authored-by: timmy
---
 test/test_linearizer.py | 117 ++++++++++++++++++++++++++++++++--------
 1 file changed, 94 insertions(+), 23 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 4322c24ead..f3e52543b3 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -101,13 +101,39 @@ class TestLinearizer(unittest.TestCase):
     assert len(mutable_bufs) == len(stores) == 2
     assert [u.arg[0] for u in mutable_bufs] == [0, 1]
 
+  @unittest.skip("TODO: fix uops toposort")
+  def test_sum_multireduce(self):
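+    # computes (x - x.sum()).sum() in one kernel: the (1, 32) -> (32, 32) expand on the first LOAD broadcasts the inner sum across the outer reduce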
+    Tensor.manual_seed(0)
+    x = Tensor.randn(32, dtype=dtypes.float).realize()
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, 32)).expand((32, 32))))
+    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (1,))
+    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((32, 1))))
+    squares = (second_x-first_reduce)
+    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (0,))
+    store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1, 1))))
+    wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1)
+    helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
+
+  @unittest.skip("TODO: fix uops toposort")
+  def test_double_sum_multireduce(self):
+    Tensor.manual_seed(0)
+    x = Tensor.randn(2, 32, 4, 16, dtype=dtypes.float).realize()
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((2, 1, 32, 4, 1, 16)).expand((2, 32, 32, 4, 16, 16))))
+    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,5))
+    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((2, 32, 1, 4, 16, 1))))
+    squares = (second_x-first_reduce)
+    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,4))
+    store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((2, 1, 1, 4, 1, 1))))
+    wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((2,1,1,4,1,1))
+    helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
+
   @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
-  @unittest.skip("still broken")
+  @unittest.skip("TODO: fix uops toposort")
   def test_var_multireduce(self):
     Tensor.manual_seed(0)
-    x = Tensor.randn(3, 27, 32).realize()
+    x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
     # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
-    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1)).expand((3, 27, 32, 32))))
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))))
     first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
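+    # 0.03125 = 1/32, so mean = first_reduce / 32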
     mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1))))  # noqa: E501
     # store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
@@ -115,11 +141,12 @@ class TestLinearizer(unittest.TestCase):
     second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
     squares = (second_x-mean)*(second_x-mean)
     squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
-    store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
-    wanna_output = x.numpy().var(axis=2, ddof=0)
+    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1))))  # noqa: E501
+    store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
+    wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
     helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
     # tinygrad ref
-    y_tiny = x.var(axis=2, correction=0)
+    y_tiny = x.var(axis=2, correction=0).reshape(3,27,1,1)
     np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)
 
   # *** buildup to fused indexing
@@ -323,11 +350,21 @@ class TestLinearizer(unittest.TestCase):
     x = Tensor.randn(8,7).softmax().realize()
     helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
 
-  @unittest.skip("AST has implicit movement ops")
+  @unittest.skip("TODO: fix uops toposort")
   def test_multireduce_unroll(self):
     # unrolled multireduceops will cause an issue where a reduceop following another reduceop will need to bring the "unroll" back:
     # ex you unroll into four values, the four values sum, then you need four operations on the sum for the next reduceop
-    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 12), strides=(12, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 12), strides=(12, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),), arg=None),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),  # noqa: E501
+    Tensor.manual_seed(0)
+    x = Tensor.randn(3, 27, 12, dtype=dtypes.float).realize()
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 12)).expand((3, 27, 12, 12))))
+    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
+    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/12, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 12, 1))))  # noqa: E501
+    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 12, 1))))
+    squares = (second_x-mean)*(second_x-mean)
+    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
+    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/12, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1))))  # noqa: E501
+    store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
+    wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
     opts = [
       [Opt(op=OptOps.UNROLL, axis=0, amt=12)],
       [Opt(op=OptOps.UNROLL, axis=0, amt=6)],
@@ -335,8 +372,7 @@ class TestLinearizer(unittest.TestCase):
       [Opt(op=OptOps.UNROLL, axis=0, amt=3)],
       [Opt(op=OptOps.UNROLL, axis=0, amt=2)],
     ]
-    x = Tensor.randn(2,12).softmax().realize()
-    helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
+    helper_linearizer_ast((store,), [x], opts=opts, wanna_output=[wanna_output])
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skip("AST has implicit movement ops")
@@ -357,26 +393,61 @@ class TestLinearizer(unittest.TestCase):
       [((x.numpy() - x.numpy().mean(axis=2, keepdims=True))/x.numpy().std(axis=2, keepdims=True, ddof=0)).sum(axis=2).reshape(-1)])
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
-  @unittest.skip("AST has implicit movement ops")
+  @unittest.skip("TODO: fix uops toposort")
   def test_mean_std_multireduce(self):
-    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))),  # noqa: E501
-    x = Tensor.randn(15, 25, 35).realize()
-    helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std()])
+    Tensor.manual_seed(0)
+    x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
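+    # same mean/variance chain as test_var_multireduce, normalized by 1/35, with a final SQRT so the kernel computes std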
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
+    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
+    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1))))  # noqa: E501
+    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
+    squares = (second_x-mean)*(second_x-mean)
+    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
+    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1))))  # noqa: E501
+    std = LazyOp(UnaryOps.SQRT, (variance,), None)
+    store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
+    wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
+    helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
-  @unittest.skip("AST has implicit movement ops")
+  @unittest.skip("TODO: fix uops toposort")
   def test_mean_std_multireduce_mid_dim(self):
-    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None)), arg=None),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 1, 35), strides=(35, 35, 1), offset=0, mask=None, contiguous=True),)))),  # noqa: E501
-    x = Tensor.randn(15, 25, 35).realize()
-    helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(1).reshape(-1)])
+    Tensor.manual_seed(0)
+    x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
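+    # same std computation, reducing over the middle axis instead: 0.04 = 1/25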
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))))
+    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
+    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 35))))  # noqa: E501
+    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35))))
+    squares = (second_x-mean)*(second_x-mean)
+    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,))
+    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 1, 1, 35))))  # noqa: E501
+    std = LazyOp(UnaryOps.SQRT, (variance,), None)
+    store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 1, 1, 35))))
+    wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
+    helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
-  @unittest.skip("AST has implicit movement ops")
+  @unittest.expectedFailure
   def test_mean_std_multireduce_multiout(self):
-    std = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),),arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))  # noqa: E501
-    mean = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))  # noqa: E501
-    x = Tensor.randn(15, 25, 35).realize()
-    helper_linearizer_ast((std,mean), [x], wanna_output=[x.numpy().std(), x.numpy().mean()])
+    Tensor.manual_seed(0)
+    x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
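+    # two STOREs from one kernel: buffer 0 gets std and buffer 1 gets mean, both built from the same loads of input buffer 2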
+    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
+    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
+    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1))))  # noqa: E501
+    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
+    squares = (second_x-mean)*(second_x-mean)
+    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
+    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1))))  # noqa: E501
+    std = LazyOp(UnaryOps.SQRT, (variance,), None)
+    third_reduce = LazyOp(ReduceOps.SUM, (second_x,), (2,))
+    mean_out = third_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1))))  # noqa: E501
+    store_mean = LazyOp(BufferOps.STORE, (mean_out,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((15,25,1,1))))
+    store_std = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
+    wanna_output = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)]
+    lins = helper_linearizer_ast((store_std,store_mean), [x], wanna_output=wanna_output)
+
+    for k in lins:
+      assert len([u for u in k.uops if u.op is UOps.DEFINE_ACC]) == 2, "got more than two accs (didn't reuse the mean reduce)"
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   @unittest.skip("AST has implicit movement ops")