diff --git a/docs/developer/layout.md b/docs/developer/layout.md index 2f9a53a4b1..ab7701fbde 100644 --- a/docs/developer/layout.md +++ b/docs/developer/layout.md @@ -22,12 +22,6 @@ Group UOps into kernels. Transforms the ast into an optimized ast. This is where BEAM search and heuristics live. -::: tinygrad.codegen.opt.get_optimized_ast - options: - members: false - show_labels: false - show_source: false - --- ## tinygrad/codegen diff --git a/test/opt/test_gen_float4.py b/test/opt/test_gen_float4.py new file mode 100644 index 0000000000..a23d633ad7 --- /dev/null +++ b/test/opt/test_gen_float4.py @@ -0,0 +1,229 @@ +import unittest +from tinygrad import Device, Tensor, dtypes +from tinygrad.uop.ops import UOp, Ops +from tinygrad.codegen.opt import Opt, OptOps +from tinygrad.shape.shapetracker import ShapeTracker, View +from tinygrad.engine.realize import get_program +from tinygrad.helpers import AMX + +@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") +class TestFloat4(unittest.TestCase): + @staticmethod + def count_float4(uops: list[UOp], n=4): + return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]), + len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)])) + @staticmethod + def count_half4(uops: list[UOp]): + return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]), + len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)])) + + def test_float4_basic(self): + a = Tensor.empty(2, 8).realize() + b = Tensor.empty(2, 8).realize() + c = a + b + + s = c.schedule()[0] + realized_ast = s.ast + opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply) + + assert TestFloat4.count_float4(program.uops) == (2, 1) + + @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16") + def test_float4_multidim(self): + a = Tensor.empty(2, 8).realize() + b = Tensor.empty(2, 8).realize() + c = a + b + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops + assert TestFloat4.count_float4(uops) == (4, 2) + + @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16") + def test_float4_multidim_amx(self): + def kernel_for_shape(size, shift): + a = Tensor.empty(2, size).realize() + b = Tensor.empty(2, size).realize() + c = a + b + + s = c.schedule()[0] + return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops + + sizes = [12, 8, 16] + shifts = [3, 2, 4] + expected_upcast_size = [4, 8, 16] + expected_output = [(6,3), (2,1), (2,1)] + + for i in range(len(sizes)): + assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] + + def test_float4_unaligned_load(self): + a = Tensor.empty(9).realize().shrink(((1, 9),)) + b = Tensor.empty(9).realize().shrink(((1, 9),)) + c = a + b + + s = c.schedule()[0] + realized_ast = s.ast + opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] + program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply) + + assert TestFloat4.count_float4(program.uops) == (0, 1) + + @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16") + 
def test_float4_multidim_unaligned_load(self):
+    a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
+    b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
+    c = a + b
+
+    s = c.schedule()[0]
+    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
+
+    assert TestFloat4.count_float4(uops) == (0, 2)
+
+  @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
+  def test_float4_multidim_unaligned_load_amx(self):
+    def kernel_for_shape(size, shift):
+      a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
+      b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
+      c = a + b
+
+      s = c.schedule()[0]
+      return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops
+
+    sizes = [13, 9, 17]
+    shifts = [3, 2, 4]
+    expected_upcast_size = [4, 8, 16]
+    expected_output = [(0,3), (0,1), (0,1)]
+
+    for i in range(len(sizes)):
+      assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
+
+  def test_float4_sometimes_unaligned(self):
+    a = Tensor.empty(1, 1, 8).realize()
+    b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
+    c = a.conv2d(b)
+    # only the first and last conv dot products are aligned in a, and b is never aligned, so no
+    # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
+
+    s = c.schedule()[0]
+    uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
+
+    assert TestFloat4.count_float4(uops) == (0, 0)
+
+  def test_float4_multidim_sometimes_unaligned(self):
+    a = Tensor.empty(1, 1, 7).realize()
+    b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
+    c = a.conv2d(b)
+    # the first conv dot product is aligned in a. If we upcast the output and reduce
+    # dimension, then we could do float4 for only that one set of loads, but we currently
+    # don't.
+    # UPDATE: now we do this fusion
+
+    s = c.schedule()[0]
+    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
+
+    assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}
+
+  def test_float4_expand(self):
+    a = Tensor.empty(9).realize().shrink(((1, 9),))
+    b = Tensor.empty(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,))
+    c = a + b
+
+    # we will upcast the top axis of size 4. the loads should not be coalesced into float4,
+    # since the top axis is not contiguous.
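+    # (for b, the expand gives the upcasted axis stride 0, and for a, the shrink by one
+    # element leaves its loads unaligned, so only the store is coalesced into a float4)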
+ + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops + + assert TestFloat4.count_float4(uops) == (0, 1) + + def test_float4_heterogeneous(self): + a = Tensor.empty(8).realize() + b = Tensor.empty(9).realize().shrink(((1, 9),)) + c = a + b + + # should float4 b but not a + + s = c.schedule()[0] + uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops + + assert TestFloat4.count_float4(uops) == (1, 1) + + def test_half4_load_unrolled(self): + # from llama 7B shard 4 gpus + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.MUL, dtypes.half, arg=None, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),)) + + # TODO: fix this, expected might change but should be positive + for expected, opts in [ + ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ]: + program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts) + + count = TestFloat4.count_half4(program.uops) + assert count == expected, f"{count=}, {expected=}" + + @unittest.skip("this doesn't happen anymore") + def test_float4_acc(self): + # from float32 stable diffusion red tinybox + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.VIEW, dtypes.float.ptr(33554432), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),)), + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.VIEW, dtypes.float.ptr(67108864), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(67108864), arg=1, src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.VIEW, dtypes.float.ptr(294912), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 
256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2, src=()),)),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.VIEW, dtypes.float.ptr(128), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), arg=3, src=()),)),)),)),)),)) + + for expected, opts in [ + (1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]), + (4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]), + ]: + program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts) + count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(4)]) + assert count == expected, f"{count=}, {expected=}" + + @unittest.skip("this doesn't happen anymore") + def test_float2_acc(self): + # from resnet + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.VIEW, dtypes.half.ptr(212926464), arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(212926464), arg=0, src=()),)), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.VIEW, dtypes.half.ptr(462422016), arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(462422016), arg=1, src=()),)),)),)),)),)),)),)) + for expected, opts in [ + (16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501 + (4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]), + ]: + program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts) + count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(2)]) + assert count == expected, f"{count=}, {expected=}" + +if __name__ == '__main__': + unittest.main() diff --git a/test/opt/test_kernel_opts.py b/test/opt/test_kernel_opts.py new file mode 100644 index 0000000000..301204726f --- /dev/null +++ b/test/opt/test_kernel_opts.py @@ -0,0 +1,326 @@ +import unittest +from tinygrad import Device, Tensor, dtypes +from tinygrad.helpers import CI +from tinygrad.codegen.opt import Opt, OptOps, KernelOptError + +# TODO: write a clean version of this +from test.test_linearizer import helper_linearizer_opt + +class 
TestKernelOpts(unittest.TestCase): + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_local_and_grouped_reduce(self): + N = 128 + Tensor.manual_seed(1882) + a = Tensor.rand(4, 4, N, N) + b = Tensor.rand(4, 4, N) + r = (b.sqrt() + ((a+1).sum(axis=3).exp())) + helper_linearizer_opt(r, [ + [Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 0, 8)], + [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)], + [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)], + # Checking how it works with locals + grouped reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)], + # Checking how it works with locals + grouped reduce + upcasts + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], + # many local + many group + [Opt(OptOps.GROUP, 0, 2)] * 4, + [Opt(OptOps.LOCAL, 0, 2)] * 4, + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4, + ]) + + def test_upcasts(self): + N = 16 + Tensor.manual_seed(1772) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = (a+b).sqrt() * ((a+1).exp()) + helper_linearizer_opt(r, [ + [Opt(OptOps.UPCAST, 0, 2)], + [Opt(OptOps.UPCAST, 0, 4)], + [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts + ]) + + def test_full_upcast(self): + Tensor.manual_seed(1772) + a = Tensor.rand(4) + b = Tensor.rand(4) + r = (a+b).sqrt() * ((a+1).exp()) + helper_linearizer_opt(r, [ + [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_matmul(self): + N = 128 + Tensor.manual_seed(1552) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = a@b + helper_linearizer_opt(r, [ + [Opt(OptOps.UPCAST, 0, 2)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts + [Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 1, 32)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)], + [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce + # Checking all together + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 2)], + # Full global upcast + local + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + 
def test_double_reduce(self):
+    N = 128
+    Tensor.manual_seed(1552)
+    a = Tensor.rand(8, N, 8, N)
+    r = a.sum(axis=(1,3))
+    helper_linearizer_opt(r, [
+      # OpenCL / GPU=1 is 256 max threads
+      [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
+      [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce.
+      [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
+      [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)],
+      [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces.
+      [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)],
+      [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts.
+      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)],
+      # Checking how it works with 2 grouped_reduces + upcasts + locals.
+      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)],
+      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)],
+      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
+       Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals.
+      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
+       Opt(OptOps.UPCAST, 0, 2)], # No globals
+    ])
+
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
+  @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
+                       "test requires tensor cores with accumulation in half") # testing with half suffices.
+  def test_tensor_core_opts(self):
+    N = 128
+    Tensor.manual_seed(1552)
+    a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
+    r = a.matmul(b, dtype=dtypes.half)
+    atol, rtol = 0.25, 0.01
+    helper_linearizer_opt(r, [
+      [],
+      [Opt(OptOps.UPCAST, 0, 4)],
+      [Opt(OptOps.UPCAST, 1, 4)],
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
+      [Opt(OptOps.UNROLL, 0, 2)], # check unroll
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and upcast
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
+      [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
+      [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
+      [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
+    ], apply_tc=True, atol=atol, rtol=rtol)
+
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
+  @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
+                       "test requires tensor cores with accumulation in half") # testing with half suffices.
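+  # locals are exercised below via Opt(OptOps.LOCAL, ...), so has_local is also required: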
+ @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + def test_tensor_core_opts_locals(self): + N = 128 + Tensor.manual_seed(1552) + a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) + r = a.matmul(b, dtype=dtypes.half) + atol, rtol = 0.25, 0.01 + helper_linearizer_opt(r, [ + [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals + [Opt(OptOps.LOCAL, 0, 4)], # check local + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], + ], apply_tc=True, atol=atol, rtol=rtol) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") + @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), + "test requires tensor cores with accumulation in half") # testing with half suffices. + # NOTE: the METAL test is broken, likely due to a compiler bug. passes on CI with -O0 and with default opt level locally on M3 + @unittest.skipIf(Device.DEFAULT == "METAL", "broken for METAL") + @unittest.skip("feature was removed") + def test_tensor_core_opts_group(self): + N = 128 + Tensor.manual_seed(1552) + a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) + r = a.matmul(b, dtype=dtypes.half) + atol, rtol = 0.25, 0.01 + helper_linearizer_opt(r, [ + [Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 4)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 2)], + ], apply_tc=True, atol=atol, rtol=rtol) + + def test_padto_matmul(self): + if (CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]): + self.skipTest("super slow on CUDA and AMD because of the big grid dims") + N = 17 * 17 + Tensor.manual_seed(289) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + helper_linearizer_opt(a@b, [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 1, 32)], + [Opt(OptOps.PADTO, 2, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)], + # can optimize further post PADTO + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),], + ]) + + def test_padto_upcasted_not_ok(self): + N = 4 + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + helper_linearizer_opt(a@b, [ + [Opt(OptOps.UPCAST, 0, 0)], + [Opt(OptOps.UPCAST, 1, 0)], + [Opt(OptOps.UNROLL, 0, 0)], + [Opt(OptOps.PADTO, 0, 8)], + [Opt(OptOps.PADTO, 1, 8)], + [Opt(OptOps.PADTO, 2, 8)], + ]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]]) + + def test_padto_sum_ok(self): + N = 
18 * 18 + # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension + a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100 + b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17))) + + helper_linearizer_opt(a.sum(0), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + helper_linearizer_opt(a.sum(1), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + # can pad sum reduce axis if there's no unsafe ops prior to sum + for axis in (0, 1): + helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) + # TODO: why? + if Device.DEFAULT != "WEBGPU": + helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) + helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) + + # having unsafe ops after sum is fine + helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],]) + helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],]) + + def test_padto_sum_not_ok(self): + N = 18 * 18 + # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension + a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp() + # exp is not safe to pad + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) + + b = a < 1 + # lt is not safe to pad + with self.assertRaises(KernelOptError): + helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) + + def test_padto_max(self): + N = 18 * 18 + # NOTE: this setup prevents 17 * 17 contiguous merged into one axis + a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100 + + helper_linearizer_opt(a.max(0), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + helper_linearizer_opt(a.max(1), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + # cannot pad max kernel on reduce + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],]) + with self.assertRaises(KernelOptError): + helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],]) + + def test_padto_where(self): + Tensor.manual_seed(0) + N = 17 * 17 + a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0) + helper_linearizer_opt(a.max(0), [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + def test_padto_where_multioutput(self): + Tensor.manual_seed(0) + N = 17 * 17 + r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1 + a0 = r.where(1, 0) + a1 = r.where(2, 0) + helper_linearizer_opt([a0.max(0), a1.max(0)], [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], + ]) + + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") + def test_color_shapes_with_local(self): + N = 32 + 
Tensor.manual_seed(1552) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = a@b + opts_shapes = [ + ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]), + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]), + # check to ensure local_dims are stable for full UNROLL of the first reduce + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + # check behavior for full UNROLL on an existing GROUP + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]), + ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + ([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), + ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]), + ] + helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes]) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c2f8b67786..65cfeabeb6 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -723,230 +723,6 @@ class TestLinearizer(unittest.TestCase): out = [u for u in get_program(k.ast, k.opts, k.applied_opts).uops if u.op is Ops.STORE][0] assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1 -@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") -class TestFloat4(unittest.TestCase): - @staticmethod - def count_float4(uops: list[UOp], n=4): - return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]), - len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)])) - @staticmethod - def count_half4(uops: list[UOp]): - return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]), - len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)])) - - def test_float4_basic(self): - a = Tensor.empty(2, 8).realize() - b = Tensor.empty(2, 8).realize() - c = a + b - - s = c.schedule()[0] - realized_ast = s.ast - opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))) - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - - assert TestFloat4.count_float4(program.uops) == (2, 1) - - @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16") - def test_float4_multidim(self): - a = Tensor.empty(2, 8).realize() - b = Tensor.empty(2, 8).realize() - c = a + b - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops - assert TestFloat4.count_float4(uops) == (4, 2) - - @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16") - def test_float4_multidim_amx(self): - def kernel_for_shape(size, shift): - a = Tensor.empty(2, size).realize() - b = Tensor.empty(2, size).realize() - c = a + b - - s = c.schedule()[0] - return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, 
axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops - - sizes = [12, 8, 16] - shifts = [3, 2, 4] - expected_upcast_size = [4, 8, 16] - expected_output = [(6,3), (2,1), (2,1)] - - for i in range(len(sizes)): - assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] - - def test_float4_unaligned_load(self): - a = Tensor.empty(9).realize().shrink(((1, 9),)) - b = Tensor.empty(9).realize().shrink(((1, 9),)) - c = a + b - - s = c.schedule()[0] - realized_ast = s.ast - opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] - realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))) - program = get_program(realized_ast, Device[Device.DEFAULT].renderer) - - assert TestFloat4.count_float4(program.uops) == (0, 1) - - @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16") - def test_float4_multidim_unaligned_load(self): - a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),)) - b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),)) - c = a + b - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops - - assert TestFloat4.count_float4(uops) == (0, 2) - - @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16") - def test_float4_multidim_unaligned_load_amx(self): - def kernel_for_shape(size, shift): - a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),)) - b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),)) - c = a + b - - s = c.schedule()[0] - return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops - - sizes = [13, 9, 17] - shifts = [3, 2, 4] - expected_upcast_size = [4, 8, 16] - expected_output = [(0,3), (0,1), (0,1)] - - for i in range(len(sizes)): - assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i] - - def test_float4_sometimes_unaligned(self): - a = Tensor.empty(1, 1, 8).realize() - b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) - c = a.conv2d(b) - # only the first and last conv dot products are aligned in a, and b is never aligned, so no - # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops - - assert TestFloat4.count_float4(uops) == (0, 0) - - def test_float4_multidim_sometimes_unaligned(self): - a = Tensor.empty(1, 1, 7).realize() - b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) - c = a.conv2d(b) - # the first conv dot product is aligned in a. If we upcast the output and reduce - # dimension, then we could do float4 for only that one set of loads, but we currently - # don't. - # UPDATE: now we do this fusion - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops - - assert TestFloat4.count_float4(uops) in {(0,1), (1,1)} - - def test_float4_expand(self): - a = Tensor.empty(9).realize().shrink(((1, 9),)) - b = Tensor.empty(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,)) - c = a + b - - # we will upcast the top axis of sz 4. they should not be coalesced into float4, - # since the top axis is not contiguous. 
- - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops - - assert TestFloat4.count_float4(uops) == (0, 1) - - def test_float4_heterogeneous(self): - a = Tensor.empty(8).realize() - b = Tensor.empty(9).realize().shrink(((1, 9),)) - c = a + b - - # should float4 b but not a - - s = c.schedule()[0] - uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops - - assert TestFloat4.count_float4(uops) == (1, 1) - - def test_half4_load_unrolled(self): - # from llama 7B shard 4 gpus - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( - UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.MUL, dtypes.half, arg=None, src=( - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)), - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),)) - - # TODO: fix this, expected might change but should be positive - for expected, opts in [ - ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ]: - ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - program = get_program(ast, Device[Device.DEFAULT].renderer) - - count = TestFloat4.count_half4(program.uops) - assert count == expected, f"{count=}, {expected=}" - - @unittest.skip("this doesn't happen anymore") - def test_float4_acc(self): - # from float32 stable diffusion red tinybox - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(33554432), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),)), - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( - UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(67108864), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(67108864), arg=1, src=()),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(294912), 
arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2, src=()),)),)),)),)), - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.VIEW, dtypes.float.ptr(128), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), arg=3, src=()),)),)),)),)),)) - - for expected, opts in [ - (1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]), - (4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]), - ]: - ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - program = get_program(ast, Device[Device.DEFAULT].renderer) - count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(4)]) - assert count == expected, f"{count=}, {expected=}" - - @unittest.skip("this doesn't happen anymore") - def test_float2_acc(self): - # from resnet - ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(212926464), arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(212926464), arg=0, src=()),)), - UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=( - UOp(Ops.CAST, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.half, arg=None, src=( - UOp(Ops.VIEW, dtypes.half.ptr(462422016), arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 - UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(462422016), arg=1, src=()),)),)),)),)),)),)),)) - for expected, opts in [ - (16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501 - (4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]), - ]: - ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - program = get_program(ast, Device[Device.DEFAULT].renderer) - count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(2)]) - assert count == expected, f"{count=}, {expected=}" - class TestHandCodedOpts(unittest.TestCase): def test_masked_upcast(self): layer_1 = Tensor.cat(*[Tensor.empty(5) for _ in range(4)]) @@ -1088,321 +864,5 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[] check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None) return lins -class 
TestKernelOpts(unittest.TestCase): - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") - def test_local_and_grouped_reduce(self): - N = 128 - Tensor.manual_seed(1882) - a = Tensor.rand(4, 4, N, N) - b = Tensor.rand(4, 4, N) - r = (b.sqrt() + ((a+1).sum(axis=3).exp())) - helper_linearizer_opt(r, [ - [Opt(OptOps.LOCAL, 0, 2)], - [Opt(OptOps.LOCAL, 0, 8)], - [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals - [Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)], - [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)], - # Checking how it works with locals + grouped reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)], - # Checking how it works with locals + grouped reduce + upcasts - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], - # many local + many group - [Opt(OptOps.GROUP, 0, 2)] * 4, - [Opt(OptOps.LOCAL, 0, 2)] * 4, - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4, - ]) - - def test_upcasts(self): - N = 16 - Tensor.manual_seed(1772) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - r = (a+b).sqrt() * ((a+1).exp()) - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.UPCAST, 0, 4)], - [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts - ]) - - def test_full_upcast(self): - Tensor.manual_seed(1772) - a = Tensor.rand(4) - b = Tensor.rand(4) - r = (a+b).sqrt() * ((a+1).exp()) - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts - ]) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") - def test_matmul(self): - N = 128 - Tensor.manual_seed(1552) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - r = a@b - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts - [Opt(OptOps.LOCAL, 0, 2)], - [Opt(OptOps.LOCAL, 1, 32)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)], - [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals - [Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce - # Checking all together - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), - Opt(OptOps.UPCAST, 1, 2)], - # Full global upcast + local - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], - ]) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") - 
def test_double_reduce(self): - N = 128 - Tensor.manual_seed(1552) - a = Tensor.rand(8, N, 8, N) - r = a.sum(axis=(1,3)) - helper_linearizer_opt(r, [ - # openCL / GPU=1 is 256 max threads - [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce. - [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], - [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)], - [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces. - [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)], - [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)], - # Checking how it works with 2 grouped_reduces + upcasts + locals. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), - Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), - Opt(OptOps.UPCAST, 0, 2)], # No globals - ]) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), - "test requires tensor cores with accumulation in half") # testing with half suffices. - def test_tensor_core_opts(self): - N = 128 - Tensor.manual_seed(1552) - a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) - r = a.matmul(b, dtype=dtypes.half) - atol, rtol = 0.25, 0.01 - helper_linearizer_opt(r, [ - [], - [Opt(OptOps.UPCAST, 0, 4)], - [Opt(OptOps.UPCAST, 1, 4)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts - [Opt(OptOps.UNROLL, 0, 2)], # check unroll - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)], - [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations - [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)], - [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], - ], apply_tc=True, atol=atol, rtol=rtol) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), - "test requires tensor cores with accumulation in half") # testing with half suffices. 
- @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - def test_tensor_core_opts_locals(self): - N = 128 - Tensor.manual_seed(1552) - a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) - r = a.matmul(b, dtype=dtypes.half) - atol, rtol = 0.25, 0.01 - helper_linearizer_opt(r, [ - [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals - [Opt(OptOps.LOCAL, 0, 4)], # check local - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], - ], apply_tc=True, atol=atol, rtol=rtol) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") - @unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores), - "test requires tensor cores with accumulation in half") # testing with half suffices. - # NOTE: the METAL test is broken, likely due to a compiler bug. passes on CI with -O0 and with default opt level locally on M3 - @unittest.skipIf(Device.DEFAULT == "METAL", "broken for METAL") - @unittest.skip("feature was removed") - def test_tensor_core_opts_group(self): - N = 128 - Tensor.manual_seed(1552) - a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half) - r = a.matmul(b, dtype=dtypes.half) - atol, rtol = 0.25, 0.01 - helper_linearizer_opt(r, [ - [Opt(OptOps.GROUP, 0, 2)], - [Opt(OptOps.GROUPTOP, 0, 4)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.GROUP, 0, 2)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUP, 0, 2)], - [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.GROUP, 0, 2)], - [Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 2)], - ], apply_tc=True, atol=atol, rtol=rtol) - - def test_padto_matmul(self): - if (CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]): - self.skipTest("super slow on CUDA and AMD because of the big grid dims") - N = 17 * 17 - Tensor.manual_seed(289) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - helper_linearizer_opt(a@b, [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 1, 32)], - [Opt(OptOps.PADTO, 2, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)], - # can optimize further post PADTO - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),], - ]) - - def test_padto_upcasted_not_ok(self): - N = 4 - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - helper_linearizer_opt(a@b, [ - [Opt(OptOps.UPCAST, 0, 0)], - [Opt(OptOps.UPCAST, 1, 0)], - [Opt(OptOps.UNROLL, 0, 0)], - [Opt(OptOps.PADTO, 0, 8)], - [Opt(OptOps.PADTO, 1, 8)], - [Opt(OptOps.PADTO, 2, 8)], - ]) - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]]) - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]]) - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]]) - - def test_padto_sum_ok(self): - N = 
18 * 18 - # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension - a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100 - b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17))) - - helper_linearizer_opt(a.sum(0), [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], - ]) - helper_linearizer_opt(a.sum(1), [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], - ]) - - # can pad sum reduce axis if there's no unsafe ops prior to sum - for axis in (0, 1): - helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) - helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) - helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) - helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) - helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) - # TODO: why? - if Device.DEFAULT != "WEBGPU": - helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) - helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) - - # having unsafe ops after sum is fine - helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],]) - helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],]) - - def test_padto_sum_not_ok(self): - N = 18 * 18 - # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension - a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp() - # exp is not safe to pad - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],]) - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) - - b = a < 1 - # lt is not safe to pad - with self.assertRaises(KernelOptError): - helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],]) - with self.assertRaises(KernelOptError): - helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) - - def test_padto_max(self): - N = 18 * 18 - # NOTE: this setup prevents 17 * 17 contiguous merged into one axis - a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100 - - helper_linearizer_opt(a.max(0), [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], - ]) - helper_linearizer_opt(a.max(1), [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], - ]) - - # cannot pad max kernel on reduce - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],]) - with self.assertRaises(KernelOptError): - helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],]) - - def test_padto_where(self): - Tensor.manual_seed(0) - N = 17 * 17 - a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0) - helper_linearizer_opt(a.max(0), [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], - ]) - - def test_padto_where_multioutput(self): - Tensor.manual_seed(0) - N = 17 * 17 - r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1 - a0 = r.where(1, 0) - a1 = r.where(2, 0) - helper_linearizer_opt([a0.max(0), a1.max(0)], [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], - ]) - - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") - def test_color_shapes_with_local(self): - N = 32 - 
Tensor.manual_seed(1552) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - r = a@b - opts_shapes = [ - ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]), - ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]), - # check to ensure local_dims are stable for full UNROLL of the first reduce - ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), - ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), - # check behavior for full UNROLL on an existing GROUP - ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]), - ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), - ([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), - ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]), - ] - helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes]) - if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_linearizer_rewrite.py b/test/unit/test_linearizer_rewrite.py index 01a3884f8a..46b3b1aba1 100644 --- a/test/unit/test_linearizer_rewrite.py +++ b/test/unit/test_linearizer_rewrite.py @@ -1,7 +1,7 @@ import unittest from tinygrad import Tensor, Context, Device from tinygrad.engine.realize import get_program -from tinygrad.codegen.opt.kernel import Opt, OptOps +from tinygrad.codegen.opt import Opt, OptOps from tinygrad.uop.ops import KernelInfo class TestLinearizerRewrite(unittest.TestCase): diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 950981f826..0783c7bc76 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -16,7 +16,7 @@ from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_ex from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext -from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize +from tinygrad.codegen.opt.kernel import pm_get_optimization, pm_do_optimize from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops from tinygrad.codegen.opt.postrange import pm_postrange_opt from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen diff --git a/tinygrad/codegen/opt/__init__.py b/tinygrad/codegen/opt/__init__.py index 2cfca3f640..c113287d6e 100644 --- a/tinygrad/codegen/opt/__init__.py +++ b/tinygrad/codegen/opt/__init__.py @@ -1,51 +1,26 @@ # opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search +from __future__ import annotations +from enum import Enum, auto +from dataclasses import dataclass +from tinygrad.uop.ops import AxisType -from tinygrad.codegen.opt.kernel import Kernel -from tinygrad.codegen.opt.heuristic import hand_coded_optimizations -from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, KernelInfo -from tinygrad.helpers import NOOPT, BEAM, getenv, POSTOPT -from tinygrad.renderer import Renderer -from tinygrad.uop.spec import type_verify +class 
diff --git a/tinygrad/codegen/opt/__init__.py b/tinygrad/codegen/opt/__init__.py
index 2cfca3f640..c113287d6e 100644
--- a/tinygrad/codegen/opt/__init__.py
+++ b/tinygrad/codegen/opt/__init__.py
@@ -1,51 +1,26 @@
 # opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
+from __future__ import annotations
+from enum import Enum, auto
+from dataclasses import dataclass
+from tinygrad.uop.ops import AxisType
-from tinygrad.codegen.opt.kernel import Kernel
-from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
-from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, KernelInfo
-from tinygrad.helpers import NOOPT, BEAM, getenv, POSTOPT
-from tinygrad.renderer import Renderer
-from tinygrad.uop.spec import type_verify
+class OptOps(Enum):
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
+  def __lt__(self, x:OptOps): return self.value < x.value
-def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp|None:
-  """
-  Optimize an AST based on heuristics or BEAM search.
+@dataclass(frozen=True, order=True)
+class Opt:
+  op: OptOps
+  axis: int|None = None
+  arg: int|tuple|None = None
+  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
-  Args:
-    ast: The Ops.SINK rooted AST
-    renderer: The renderer used to generate the code
+axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.LOOP: "L", AxisType.UPCAST: "u",
+                AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
+axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow",
+               AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
-  Returns:
-    The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
-  """
-
-  # no shape, no opt
-  if ast.src[0].st is None: return None
-  new_arg = ast.arg
-  if new_arg is None:
-    k = Kernel(ast, opts=renderer)
-    if not NOOPT:
-      k.apply_opts(hand_coded_optimizations(k))
-    if not POSTOPT and BEAM >= 1:
-      from tinygrad.codegen.opt.search import beam_search, bufs_from_lin
-      kb = Kernel(ast, opts=renderer)
-      rawbufs = bufs_from_lin(kb, allocate=False)
-      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-    new_arg = KernelInfo(opts_to_apply=tuple(k.applied_opts))
-  elif len(new_arg.applied_opts): return None
-  return Kernel(ast.replace(arg=None), opts=renderer).get_optimized_ast().replace(arg=new_arg)
-
-pm_get_optimization = PatternMatcher([
-  (UPat(Ops.SINK, name="ast"), lambda ctx,ast: get_optimized_ast(ast, ctx)),
-])
-
-def apply_opt(ast:UOp, renderer:Renderer):
-  k = Kernel(ast, opts=renderer)
-  k.apply_opts(ast.arg.opts_to_apply)
-  ret = k.get_optimized_ast()
-  if __debug__: type_verify(list(ret.toposort()))
-  return ret
-
-pm_do_optimize = PatternMatcher([
-  (UPat(Ops.SINK, name="ast"), lambda ctx,ast: apply_opt(ast, ctx) if ast.arg is not None and ast.arg.opts_to_apply is not None else None),
-])
+class KernelOptError(Exception): pass
+def check(cond:bool, msg:str=""):
+  if not cond: raise KernelOptError(msg)
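With this rewrite, tinygrad/codegen/opt/__init__.py becomes a dependency-light home for the shared optimization vocabulary (OptOps, Opt, the axis letter/color tables, KernelOptError, check), while everything that needs Kernel moves into kernel.py below. A quick sketch of the surface that remains, runnable as-is:

from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check

opts = sorted([Opt(OptOps.LOCAL, 1, 16), Opt(OptOps.UPCAST, 0, 4)])
print(opts[0])  # Opt(op=OptOps.UPCAST, axis=0, arg=4): order=True plus OptOps.__lt__ make Opts sortable
try: check(len(opts) == 3, "expected three opts")
except KernelOptError as e: print(e)  # check() is the guard both optimizer paths raise through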
diff --git a/tinygrad/codegen/opt/heuristic.py b/tinygrad/codegen/opt/heuristic.py
index 80f12d0ca2..aa862a0142 100644
--- a/tinygrad/codegen/opt/heuristic.py
+++ b/tinygrad/codegen/opt/heuristic.py
@@ -1,9 +1,12 @@
 import itertools
-from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
-from tinygrad.codegen.opt.postrange import Scheduler
+from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
 from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX
 from tinygrad.dtype import ImageDType
-from tinygrad.uop.ops import Ops, resolve
+from tinygrad.uop.ops import Ops, resolve, AxisType
+
+# both versions
+from tinygrad.codegen.opt.kernel import Kernel
+from tinygrad.codegen.opt.postrange import Scheduler

 def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
   # first try the tensor cores
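heuristic.py now takes Opt/OptOps/KernelOptError from the package root, and the "# both versions" block keeps Kernel and the postrange Scheduler importable side by side so hand_coded_optimizations(k:Kernel|Scheduler) serves either pipeline. A sketch of driving it by hand, mirroring what get_optimized_ast in kernel.py below does; the matmul is just a convenient ast source:

from tinygrad import Tensor, Device
from tinygrad.codegen.opt.kernel import Kernel
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations

ast = (Tensor.empty(64, 64) @ Tensor.empty(64, 64)).schedule()[-1].ast  # last kernel is the matmul reduce
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)  # a postrange Scheduler slots in the same way
k.apply_opts(hand_coded_optimizations(k))              # heuristics return a list[Opt], applied in order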
+ """ + + # no shape, no opt + if ast.src[0].st is None: return None + new_arg = ast.arg + if new_arg is None: + k = Kernel(ast, opts=renderer) + if not NOOPT: + from tinygrad.codegen.opt.heuristic import hand_coded_optimizations + k.apply_opts(hand_coded_optimizations(k)) + if not POSTOPT and BEAM >= 1: + from tinygrad.codegen.opt.search import beam_search, bufs_from_lin + kb = Kernel(ast, opts=renderer) + rawbufs = bufs_from_lin(kb, allocate=False) + k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) + new_arg = KernelInfo(opts_to_apply=tuple(k.applied_opts)) + elif len(new_arg.applied_opts): return None + return Kernel(ast.replace(arg=None), opts=renderer).get_optimized_ast().replace(arg=new_arg) + +pm_get_optimization = PatternMatcher([ + (UPat(Ops.SINK, name="ast"), lambda ctx,ast: get_optimized_ast(ast, ctx)), +]) + +def apply_opt(ast:UOp, renderer:Renderer): + k = Kernel(ast, opts=renderer) + k.apply_opts(ast.arg.opts_to_apply) + ret = k.get_optimized_ast() + if __debug__: type_verify(list(ret.toposort())) + return ret + +pm_do_optimize = PatternMatcher([ + (UPat(Ops.SINK, name="ast"), lambda ctx,ast: apply_opt(ast, ctx) if ast.arg is not None and ast.arg.opts_to_apply is not None else None), +]) \ No newline at end of file diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 554a8ea62e..d25a67cab9 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -6,7 +6,7 @@ from tinygrad.uop.symbolic import symbolic from tinygrad.device import Buffer from tinygrad.dtype import AddrSpace, dtypes from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up -from tinygrad.codegen.opt.kernel import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters +from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters from tinygrad.renderer import Renderer from tinygrad.schedule.rangeify import remove_tags diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index 7e603b454d..27efacc95b 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -7,12 +7,15 @@ from tinygrad.device import Device, Buffer, Compiler from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str from tinygrad.helpers import IGNORE_BEAM_CACHE from tinygrad.dtype import ImageDType, PtrDType -from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError -from tinygrad.codegen.opt.postrange import Scheduler +from tinygrad.codegen.opt import Opt, OptOps, KernelOptError from tinygrad.tensor import Tensor from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.renderer import ProgramSpec +# both versions +from tinygrad.codegen.opt.kernel import Kernel +from tinygrad.codegen.opt.postrange import Scheduler + actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)] actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)] actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)] diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py index 75ea81fe2f..cf4db55cb9 100644 --- a/tinygrad/schedule/kernelize.py +++ b/tinygrad/schedule/kernelize.py @@ -8,7 +8,7 @@ from tinygrad.dtype import ImageDType from tinygrad.schedule.multi import multi_pm from 
diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py
index 75ea81fe2f..cf4db55cb9 100644
--- a/tinygrad/schedule/kernelize.py
+++ b/tinygrad/schedule/kernelize.py
@@ -8,7 +8,7 @@ from tinygrad.dtype import ImageDType
 from tinygrad.schedule.multi import multi_pm
 from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
 from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
-from tinygrad.codegen.opt.kernel import Opt
+from tinygrad.codegen.opt import Opt

 # creation can recurse a lot
 import sys
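Taken together, the moves leave optimization as two pattern-matcher passes over the SINK, both now living in kernel.py: pm_get_optimization picks opts (heuristics or BEAM) and records them as KernelInfo(opts_to_apply=...) without applying them, and pm_do_optimize replays them onto the ast. A condensed sketch of that flow, assuming graph_rewrite's ctx= keyword carries the Renderer the way the lambdas above expect:

from tinygrad import Tensor, Device
from tinygrad.uop.ops import graph_rewrite
from tinygrad.codegen.opt.kernel import pm_get_optimization, pm_do_optimize

sink = (Tensor.empty(64, 64) @ Tensor.empty(64, 64)).schedule()[-1].ast
renderer = Device[Device.DEFAULT].renderer
planned = graph_rewrite(sink, pm_get_optimization, ctx=renderer)  # tags KernelInfo(opts_to_apply=...)
final = graph_rewrite(planned, pm_do_optimize, ctx=renderer)      # applies the recorded opts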