diff --git a/test/test_linearizer.py b/test/test_linearizer.py index d1e003089f..f5bf868e64 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -195,9 +195,10 @@ class TestLinearizer(unittest.TestCase): end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE) assert lin.uops[end+1].uop is UOps.ALU - @unittest.skip + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_early_end_local(self): + # with multiple reductions, make sure the if block of one reduceop is closed before starting another ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 k = Linearizer(*ast) k.hand_coded_optimizations() @@ -213,8 +214,9 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(3,27,32).realize() helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)]) - @unittest.skip + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_reduceops_order(self): + # make sure that the kernel put reduceops in the order of their dependencies when passed to the Linearizer in arbitrary order load = MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker.from_shape((32,))) ast0 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load),), arg=(0,)) ast1 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load), ast0)),), arg=(0,)) @@ -234,9 +236,10 @@ class TestLinearizer(unittest.TestCase): outs = [b:=(a:=x.numpy()).sum(), c:=(a - b).sum(), d:=(c - a).sum(), (a-d + a-b).sum()] helper_linearizer_ast(tuple(asts[i] for i in order), [x], wanna_output=[outs[i] for i in order]) - @unittest.skip + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_multireduce_store_locals(self): + # ensure the result of local reducop is stored and loaded back into every thread for future use ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 k = Linearizer(*ast) k.hand_coded_optimizations() @@ -251,20 +254,37 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(3,27,32).realize() helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)]) - @unittest.skip def test_multireduce_upcasting(self): - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 + # when upcasting multiple reductions, ensure ast_parse will create multiple uops even when using the result of past reductions + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 k = Linearizer(*ast) k.upcast() k.linearize() define_globals = [u for u in k.uops if u.uop is UOps.DEFINE_GLOBAL] self.assertEqual(len([u for u in k.uops if u.uop is UOps.LOAD and define_globals[1] in u.vin]), 7) self.assertEqual(len([u for u in k.uops if u.uop is UOps.ALU and u.arg is BinaryOps.SUB]), 7) - x = Tensor.randn(2,7).softmax().realize() - helper_linearizer_ast(ast, [x], wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)]) + opts = [[Opt(op=OptOps.UPCAST, axis=0, amt=2)], [Opt(op=OptOps.UPCAST, axis=0, amt=4)]] + x = Tensor.randn(8,7).softmax().realize() + helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)]) - @unittest.skip + def test_multireduce_unroll(self): + # unrolled multireduceops will cause an issue where and reduceop following another reduceop will need to bring the "unroll" back: + # ex you unroll into four values, the four values sum, then you need to four operations on the sum for the next reduceop + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 12), strides=(12, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 12), strides=(12, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(2, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 + opts = [ + [Opt(op=OptOps.UNROLL, axis=0, amt=12)], + [Opt(op=OptOps.UNROLL, axis=0, amt=6)], + [Opt(op=OptOps.UNROLL, axis=0, amt=4)], + [Opt(op=OptOps.UNROLL, axis=0, amt=3)], + [Opt(op=OptOps.UNROLL, axis=0, amt=2)], + ] + x = Tensor.randn(2,12).softmax().realize() + helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)]) + + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_multireduce_loop_scope(self): + # when rendering multiple reducops, any arithmetic on the result of one reduceop will be rendered within the loop of the next + # these ops need to be moved out of the loop so they be accessed in any scope ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None))),LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501 k = Linearizer(*ast) k.hand_coded_optimizations() @@ -280,25 +300,55 @@ class TestLinearizer(unittest.TestCase): helper_linearizer_ast(ast, [x], wanna_output= \ [((x.numpy() - x.numpy().mean(axis=2, keepdims=True))/x.numpy().std(axis=2, keepdims=True, ddof=0)).sum(axis=2).reshape(-1)]) - @unittest.skip + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_mean_std_multireduce(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 x = Tensor.randn(15, 25, 35).realize() helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std()]) - @unittest.skip + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_mean_std_multireduce_mid_dim(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.04, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 1, 35), strides=(35, 35, 1), offset=0, mask=None, contiguous=True),)))), # noqa: E501 x = Tensor.randn(15, 25, 35).realize() helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(1).reshape(-1)]) - @unittest.skip + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_mean_std_multireduce_multiout(self): std = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 mean = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 x = Tensor.randn(15, 25, 35).realize() helper_linearizer_ast((std,mean), [x], wanna_output=[x.numpy().std(), x.numpy().mean()]) + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") + def test_softmax_multireduce(self): + x = Tensor.rand(4, 32).realize() + x_ast = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((4,32)))) + max_x = LazyOp(op=ReduceOps.MAX, src=(x_ast,), arg=(1,)) + centered_x = LazyOp(op=BinaryOps.SUB, src=(x_ast, max_x)) + exp_x = LazyOp(op=UnaryOps.EXP2, src=(centered_x,)) + sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,)) + y = LazyOp(op=BinaryOps.DIV, src=(exp_x, sum_exp_x)) + y_reduced = LazyOp(op=ReduceOps.SUM, src=(y,), arg=(1,)) + ast = LazyOp(op=BufferOps.STORE, src=(y_reduced,), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) + expected = ((np_exp2:=np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)))/np_exp2.sum(axis=-1, keepdims=True)).sum(axis=-1) + helper_linearizer_ast((ast,), [x], wanna_output=[expected]) + + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") + def test_softmax_multireduce_multiout(self): + x = Tensor.rand(4, 32).realize() + x_ast = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((4,32)))) + max_x = LazyOp(op=ReduceOps.MAX, src=(x_ast,), arg=(1,)) + exp_x = LazyOp(op=UnaryOps.EXP2, src=(LazyOp(op=BinaryOps.SUB, src=(x_ast, max_x)),)) + sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,)) + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.DIV, src=(exp_x, sum_exp_x)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) # noqa: E501 + max_x_ast = LazyOp(op=BufferOps.STORE, src=(max_x,), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) + sum_exp_x_ast = LazyOp(op=BufferOps.STORE, src=(sum_exp_x,), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) + expected = [ + ((np_exp2:=np.exp2(x.numpy()-(np_max_x:=x.numpy().max(axis=-1,keepdims=True))))/(sum_exp_x:=np_exp2.sum(axis=-1,keepdims=True))).sum(axis=-1,), + np_max_x.reshape(-1), sum_exp_x.reshape(-1) + ] + helper_linearizer_ast((ast,max_x_ast,sum_exp_x_ast), [x], wanna_output=expected) + def test_load_dedup(self): # for different leaves in the AST, the same loads may occur. @@ -378,9 +428,9 @@ class TestLinearizer(unittest.TestCase): assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_upcast_multireduce_nested_local_upcast(self): - x, y, z, w = [Tensor.rand(1,128).realize() for _ in range(4)] + x, y, z, w = [Tensor.rand((1,128) if i % 2 == 0 else (1,128,128)).realize() for i in range(4)] st0 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)) st1 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 1, 128), offset=0, mask=None, contiguous=False),)) ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, st0)) @@ -1051,7 +1101,7 @@ class TestKernelOpts(unittest.TestCase): [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], ]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_local_and_grouped_reduce_multireduce(self): @@ -1065,7 +1115,7 @@ class TestKernelOpts(unittest.TestCase): d = Tensor.rand(4, 4, N).realize() r1 = (d.sqrt() + ((c+1).sum(axis=3).exp())) ast = _temp_create_multireduce_ast(r0, r1) - helper_linearizer_ast(ast, [a, b, c, d], [ + helper_linearizer_ast(ast, [b, a, d, c], [ [Opt(OptOps.LOCAL, 0, 2)], [Opt(OptOps.LOCAL, 0, 8)], [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals @@ -1081,7 +1131,7 @@ class TestKernelOpts(unittest.TestCase): [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], ]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_atomic_store_multireduce(self): @@ -1099,18 +1149,18 @@ class TestKernelOpts(unittest.TestCase): dummy = Tensor.rand(4,4,1).realize() r0,r1 = (a-dummy).sum(-1), b.sum(-1) ast = _temp_create_multireduce_ast(r0, r1, replace_idxs={2:r1}, merge=lambda r0,_: r0) - lins += helper_linearizer_ast(ast, [a,b], [[Opt(OptOps.GROUP, 0, 2)]]) + lins += helper_linearizer_ast(ast, [a], [[Opt(OptOps.GROUP, 0, 2)]]) for k in lins: - local_bufs = [u for u in k.uops if u.uop is UOps.DEFINE_LOCAL] seen_bar = False for u in k.uops: if u.uop is UOps.BARRIER: assert not seen_bar, "redudant barrier" seen_bar = True - elif (u.uop is UOps.LOAD or u.uop is UOps.STORE) and any([v in local_bufs for v in u.vin]): seen_bar = False + elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False - @unittest.skip("multireduce isn't supported yet") + @unittest.skip("TODO: broken") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_atomic_store_unrolled_multireduce(self): @@ -1124,15 +1174,14 @@ class TestKernelOpts(unittest.TestCase): ]) for k in lins: - local_bufs = [u for u in k.uops if u.uop is UOps.DEFINE_LOCAL] seen_bar = False for u in k.uops: if u.uop is UOps.BARRIER: assert not seen_bar, "redudant barrier" seen_bar = True - elif (u.uop is UOps.LOAD or u.uop is UOps.STORE) and any([v in local_bufs for v in u.vin]): seen_bar = False + elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_atomic_store_nested_range_multireduce(self): @@ -1150,13 +1199,12 @@ class TestKernelOpts(unittest.TestCase): ]) for k in lins: - local_bufs = [u for u in k.uops if u.uop is UOps.DEFINE_LOCAL] seen_bar = False for u in k.uops: if u.uop is UOps.BARRIER: assert not seen_bar, "redudant barrier" seen_bar = True - elif (u.uop is UOps.LOAD or u.uop is UOps.STORE) and any([v in local_bufs for v in u.vin]): seen_bar = False + elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False def test_upcasts(self): N = 16 @@ -1208,7 +1256,7 @@ class TestKernelOpts(unittest.TestCase): [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], ]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_matmul_multireduce(self): @@ -1268,7 +1316,7 @@ class TestKernelOpts(unittest.TestCase): Opt(OptOps.UPCAST, 0, 2)], # No globals ]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_double_reduce_multireduce(self): @@ -1322,7 +1370,7 @@ class TestKernelOpts(unittest.TestCase): with self.assertRaises(KernelOptError): k.apply_opt(Opt(OptOps.TC, 0, 1)) - @unittest.skip("multireduce isn't supported yet") + @unittest.skip("TODO: update TC tests") @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_invalid_fused_tensor_core(self): Tensor.manual_seed(1552) @@ -1371,8 +1419,7 @@ class TestKernelOpts(unittest.TestCase): # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC) ], apply_tc=True, atol=atol, rtol=rtol) - # NOTE: indexing issue happens here - @unittest.skip("multireduce isn't supported yet") + @unittest.skip("TODO: update TC tests") @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_fused_tensor_core_simple(self): N = 64 @@ -1384,7 +1431,7 @@ class TestKernelOpts(unittest.TestCase): r1 = c.matmul(d, acc_dtype=tc.dtype_out) check_fused_tc_opt(tc, r0, r1, [a, b, c, d]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skip("TODO: update TC tests") @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_fused_tensor_core_permuted(self): N = 64 @@ -1504,7 +1551,7 @@ class TestKernelOpts(unittest.TestCase): [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_padto_sum_multireduce(self): Tensor.manual_seed(0) N = 17 @@ -1520,15 +1567,14 @@ class TestKernelOpts(unittest.TestCase): helper_linearizer_ast(ast((1, ), (17, 1)), [x], opts=opts, wanna_output=[(x.numpy()-x.numpy().sum(axis=1,keepdims=True)).sum(1)]) # pad reduce axis TODO: broken - with self.assertRaises(AssertionError): - expected = (x.numpy()-x.numpy().sum(axis=0,keepdims=True)).sum(0) - helper_linearizer_ast(ast((0, ), (1, 17)), [x], opts=[[Opt(OptOps.PADTO, 1, 32)]], wanna_output=[expected]) + unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") + expected = (x.numpy()-x.numpy().sum(axis=0,keepdims=True)).sum(0) + helper_linearizer_ast(ast((0, ), (1, 17)), [x], opts=[[Opt(OptOps.PADTO, 1, 32)]], wanna_output=[expected]) - with self.assertRaises(AssertionError): - op = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.SUB, src=(x_ld,LazyOp(op=ReduceOps.SUM, src=(x_ld,), arg=(0,1)))),), arg=(0,1)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 - helper_linearizer_ast((op,), [x], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[(x.numpy()-x.numpy().sum(keepdims=True)).sum()]) + op = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.SUB, src=(x_ld,LazyOp(op=ReduceOps.SUM, src=(x_ld,), arg=(0,1)))),), arg=(0,1)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + helper_linearizer_ast((op,), [x], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[(x.numpy()-x.numpy().sum(keepdims=True)).sum()]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_padto_max_multireduce(self): Tensor.manual_seed(0) N = 17 @@ -1543,7 +1589,7 @@ class TestKernelOpts(unittest.TestCase): helper_linearizer_ast(ast((0, ), (1, 17)), [x], opts=opts, wanna_output=[(x.numpy()+x.numpy().max(axis=0,keepdims=True)).max(0)]) helper_linearizer_ast(ast((1, ), (17, 1)), [x], opts=opts, wanna_output=[(x.numpy()+x.numpy().max(axis=1,keepdims=True)).max(1)]) - @unittest.skip("multireduce isn't supported yet") + @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_padto_where_multireduce(self): # we need to make sure the ternary operators nest properly N = 17 @@ -1560,16 +1606,17 @@ class TestKernelOpts(unittest.TestCase): wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),), arg=(0,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),)),), arg=(0,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,N)))) # noqa: E501 helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output]) + # pad reduce axis + unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output]) wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),), arg=(0,1,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),)),), arg=(0,1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))) # noqa: E501 helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output.flatten()]) - @unittest.skip("multireduce isn't supported yet") def test_padto_matmul_multireduce(self): - if CI and Device.DEFAULT in ["AMD", "NV"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims") + if CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims") N = 17 * 17 Tensor.manual_seed(289) a = Tensor.rand(N, N).realize() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 65f0472c24..6f57b2030a 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -100,6 +100,11 @@ class TestLinearizerFailures(unittest.TestCase): opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) + def test_failure_12_multireduce(self): + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=(0, 2, 4, 8)),)),), arg=(0, 2, 4, 8)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)),)) + opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] + helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) + # both kernels are correct from a code standpoint, but generate different results due to precision errors (switching to float results in output matches) def test_failure_13(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index a12ecf1348..7a3ec285ca 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -59,19 +59,17 @@ class Kernel: self.ast = ast self.lazyops = flatten([op.lazyops for op in self.ast]) - # there's only allowed to be one reduceop cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {} def ordered_lazyops(op): if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op]) return cached_ordered_lazyops[op] self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps]) - assert len(self.reduceops) < 2, "Only one reduceop allowed" self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast]) loadops = [BufferOps.LOAD, BufferOps.CONST] self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops]) - # get earlybufs, before the one reduce op + # get earlybufs, before any reduceops self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps] self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0 diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 24f93b0f6d..cc7ddc3510 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -221,10 +221,10 @@ class Linearizer(Kernel): global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs, alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]: # reduce loop - loop_ctx = self.render_loop(reduce_idxs, 2) + loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2) # define accumulator - modify idxs if necessary for TC - out_buf = -len(self.reduceops)+self.reduceops.index(reduceop) if self.group_for_reduces else 0 + out_buf = -len(self.reduceops)+i if self.group_for_reduces else 0 accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx) # store local aliases @@ -260,8 +260,13 @@ class Linearizer(Kernel): loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[reduceop][i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs}) + def gate_acc(r, idxs): return [ + UOp.alu(TernaryOps.WHERE, valid.render(self.render_ops, self), acc, self.const(0, r.dtype)) if valid.min == 0 and valid.max == 1 else acc + for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])] + local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs} + # run early AST (with reduce) - self.ast_parse(reduceop, accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop]) + self.ast_parse(reduceop, local_accs, self.acc_offsets(self.full_buf_index), loaded_buffers, reduce_acc=accs[reduceop]) # end the reduce loop self.load_cache.clear() @@ -294,13 +299,13 @@ class Linearizer(Kernel): # NOTE: this structure is the same as the reduce op above # late reduce loop - loop_ctx = self.render_loop(end_local_idxs, 3) + loop_ctx = self.render_loop(end_local_idxs, i*2+3) # define late accumulator accs[reduceop] = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx) # load localbufs - loaded_buffers[self.bufs[out_buf]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) + loaded_buffers[self.bufs[out_buf]] = self.global_load(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) # there's no AST here (and there's no shape for the reduce LazyOp) self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[out_buf]),)),\ @@ -309,8 +314,15 @@ class Linearizer(Kernel): # end the late reduce loop self.load_cache.clear() - # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have - # been rewritten with fake end_local_idxs. + if reduceop is not self.reduceops[-1]: + for j in self.upcast_in_mid_reduce_axes: + self.upcasted -= 1 + self.group_for_reduces += 1 + assert self.buf_uops[out_buf] is not None, "Local reduce buf must have been uoped at this point" + fake_local_idxs = local_idxs[:self.local_dims] + [x*0 for x in local_idxs[self.local_dims:]] + stores = self.global_store(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) + barrier = UOp(UOps.BARRIER, None, tuple(stores)) + accs[reduceop] = self.global_load(out_buf, fake_global_idxs+fake_local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) @@ -429,8 +441,12 @@ class Linearizer(Kernel): if len(reduceops) != 0: # TODO: delete render_reduceop and move the logic for group_for_reduces to Block - local_idxs[:], upcast_idxs[:] = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, full_upcast_idxs, - reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r]) + nlidx, nuidx = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,\ + global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r]) + + # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have + # been rewritten with fake end_local_idxs. + if r is self.reduceops[-1]: local_idxs[:], upcast_idxs[:] = nlidx, nuidx return [accs[r]] # load latebufs @@ -448,8 +464,7 @@ class Linearizer(Kernel): return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST, self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)] if x.op in ReduceOps and reduce_acc is None: - assert offs is None, "not available if we aren't doing reduce" - return accs[x] + return [accs[x][i] for i in offs] if offs else accs[x] values = [self.ast_parse(v, accs, offs, loaded_buffers, cache=cache) for v in x.src] ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX} diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 5e2638111a..b8100a5a8c 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -364,10 +364,10 @@ class UOpGraph: add_parents(sink) @functools.lru_cache(None) - def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]: + def get_recursive_children(x:UOp, include_self=False, stop=lambda x,u: True) -> Set[UOp]: if x.uop is UOps.SINK: return set() - return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in graph[x]] if x.uop is not UOps.PHI else [])) - loops_children = {l:get_recursive_children(l) for l in loops[::-1]} + return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True, stop) for u in graph[x] if stop(x,u)])) + loops_children = {l:get_recursive_children(l, stop=lambda x,u: x.uop is not UOps.PHI) for l in loops[::-1]} queue: List = [] def push(u): @@ -384,6 +384,8 @@ class UOpGraph: from test.external.fuzz_uops import fuzz_uops self.fuzz_paths = fuzz_uops(graph, in_degree.copy(), loops_children) + # find all ifs children, add them to the loops children so we can add ends + scope_children = {**loops_children, **{u:get_recursive_children(u, stop=lambda x,u: u.uop is not UOps.BARRIER) for u in ifs[::-1]}} self._uops = [] while queue: p,x = heapq.heappop(queue) @@ -393,10 +395,10 @@ class UOpGraph: self._uops.insert(idx, x) else: self._uops.append(x) - for u, ss in loops_children.items(): + for u, ss in scope_children.items(): if x in ss: ss.remove(x) - if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,))) + if len(ss) == 0: self._uops.append(UOp({UOps.RANGE:UOps.ENDRANGE, UOps.IF:UOps.ENDIF}[u.uop], None, (u,))) for u in graph[x]: in_degree[u] -= 1 if in_degree[u] == 0: push(u) @@ -404,9 +406,6 @@ class UOpGraph: assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}" self._uops = self._uops[:-1] - # TODO: ifs should be removed and just the store should be gated - for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,))) - if type_verify: self.type_verify() def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp: diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index c7ebac38c8..38df1efbd3 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -133,14 +133,14 @@ class PTXRenderer(Renderer): uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg if uop is UOps.IF: assert vin[0].dtype is not None - kk(*self.render_bra(f"IF_{r[vin[0]][1:]}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True))) + kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True))) elif uop is UOps.BARRIER and self.barrier: kk(self.barrier) elif uop is UOps.ENDRANGE: kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]), self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int])) kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred)) elif uop is UOps.ENDIF: - kk(f"IF_{r[vin[0].vin[0]][1:]}:") + kk(f"IF_{r[vin[0].vin[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:") elif uop is UOps.STORE: assert vin[0].dtype is not None and vin[2].dtype is not None assert vin[0].dtype == dtypes.int64, "store isn't int64"