mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
refactor and split test_linearizer (#12001)
* refactor and split test_linearizer * forget that file * imports * remove from docs * test gen float4
This commit is contained in:
@@ -22,12 +22,6 @@ Group UOps into kernels.
|
||||
|
||||
Transforms the ast into an optimized ast. This is where BEAM search and heuristics live.
|
||||
|
||||
::: tinygrad.codegen.opt.get_optimized_ast
|
||||
options:
|
||||
members: false
|
||||
show_labels: false
|
||||
show_source: false
|
||||
|
||||
---
|
||||
|
||||
## tinygrad/codegen
|
||||
|
||||
229
test/opt/test_gen_float4.py
Normal file
229
test/opt/test_gen_float4.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import unittest
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.helpers import AMX
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
|
||||
class TestFloat4(unittest.TestCase):
|
||||
@staticmethod
|
||||
def count_float4(uops: list[UOp], n=4):
|
||||
return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]),
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)]))
|
||||
@staticmethod
|
||||
def count_half4(uops: list[UOp]):
|
||||
return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)]))
|
||||
|
||||
def test_float4_basic(self):
|
||||
a = Tensor.empty(2, 8).realize()
|
||||
b = Tensor.empty(2, 8).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
realized_ast = s.ast
|
||||
opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
|
||||
|
||||
assert TestFloat4.count_float4(program.uops) == (2, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim(self):
|
||||
a = Tensor.empty(2, 8).realize()
|
||||
b = Tensor.empty(2, 8).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
|
||||
assert TestFloat4.count_float4(uops) == (4, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_amx(self):
|
||||
def kernel_for_shape(size, shift):
|
||||
a = Tensor.empty(2, size).realize()
|
||||
b = Tensor.empty(2, size).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops
|
||||
|
||||
sizes = [12, 8, 16]
|
||||
shifts = [3, 2, 4]
|
||||
expected_upcast_size = [4, 8, 16]
|
||||
expected_output = [(6,3), (2,1), (2,1)]
|
||||
|
||||
for i in range(len(sizes)):
|
||||
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
|
||||
|
||||
def test_float4_unaligned_load(self):
|
||||
a = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
b = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
realized_ast = s.ast
|
||||
opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
|
||||
|
||||
assert TestFloat4.count_float4(program.uops) == (0, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_unaligned_load(self):
|
||||
a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_unaligned_load_amx(self):
|
||||
def kernel_for_shape(size, shift):
|
||||
a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
|
||||
b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops
|
||||
|
||||
sizes = [13, 9, 17]
|
||||
shifts = [3, 2, 4]
|
||||
expected_upcast_size = [4, 8, 16]
|
||||
expected_output = [(0,3), (0,1), (0,1)]
|
||||
|
||||
for i in range(len(sizes)):
|
||||
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
|
||||
|
||||
def test_float4_sometimes_unaligned(self):
|
||||
a = Tensor.empty(1, 1, 8).realize()
|
||||
b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
|
||||
c = a.conv2d(b)
|
||||
# only the first and last conv dot products are aligned in a, and b is never aligned, so no
|
||||
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 0)
|
||||
|
||||
def test_float4_multidim_sometimes_unaligned(self):
|
||||
a = Tensor.empty(1, 1, 7).realize()
|
||||
b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
|
||||
c = a.conv2d(b)
|
||||
# the first conv dot product is aligned in a. If we upcast the output and reduce
|
||||
# dimension, then we could do float4 for only that one set of loads, but we currently
|
||||
# don't.
|
||||
# UPDATE: now we do this fusion
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}
|
||||
|
||||
def test_float4_expand(self):
|
||||
a = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
b = Tensor.empty(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,))
|
||||
c = a + b
|
||||
|
||||
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 1)
|
||||
|
||||
def test_float4_heterogeneous(self):
|
||||
a = Tensor.empty(8).realize()
|
||||
b = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
c = a + b
|
||||
|
||||
# should float4 b but not a
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (1, 1)
|
||||
|
||||
def test_half4_load_unrolled(self):
|
||||
# from llama 7B shard 4 gpus
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),))
|
||||
|
||||
# TODO: fix this, expected might change but should be positive
|
||||
for expected, opts in [
|
||||
((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
|
||||
((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
|
||||
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
|
||||
]:
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts)
|
||||
|
||||
count = TestFloat4.count_half4(program.uops)
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.skip("this doesn't happen anymore")
|
||||
def test_float4_acc(self):
|
||||
# from float32 stable diffusion red tinybox
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(33554432), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(67108864), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(67108864), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(294912), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2, src=()),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(128), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), arg=3, src=()),)),)),)),)),))
|
||||
|
||||
for expected, opts in [
|
||||
(1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]),
|
||||
(4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]),
|
||||
]:
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts)
|
||||
count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(4)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.skip("this doesn't happen anymore")
|
||||
def test_float2_acc(self):
|
||||
# from resnet
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(212926464), arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(212926464), arg=0, src=()),)),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(462422016), arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(462422016), arg=1, src=()),)),)),)),)),)),)),))
|
||||
for expected, opts in [
|
||||
(16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501
|
||||
(4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
|
||||
]:
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer, opts=opts)
|
||||
count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(2)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
326
test/opt/test_kernel_opts.py
Normal file
326
test/opt/test_kernel_opts.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import unittest
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
|
||||
# TODO: write a clean version of this
|
||||
from test.test_linearizer import helper_linearizer_opt
|
||||
|
||||
class TestKernelOpts(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_local_and_grouped_reduce(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1882)
|
||||
a = Tensor.rand(4, 4, N, N)
|
||||
b = Tensor.rand(4, 4, N)
|
||||
r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 8)],
|
||||
[Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals
|
||||
[Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)],
|
||||
[Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
# Checking how it works with locals + grouped reduce
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)],
|
||||
# Checking how it works with locals + grouped reduce + upcasts
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)],
|
||||
# many local + many group
|
||||
[Opt(OptOps.GROUP, 0, 2)] * 4,
|
||||
[Opt(OptOps.LOCAL, 0, 2)] * 4,
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4,
|
||||
])
|
||||
|
||||
def test_upcasts(self):
|
||||
N = 16
|
||||
Tensor.manual_seed(1772)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = (a+b).sqrt() * ((a+1).exp())
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UPCAST, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts
|
||||
])
|
||||
|
||||
def test_full_upcast(self):
|
||||
Tensor.manual_seed(1772)
|
||||
a = Tensor.rand(4)
|
||||
b = Tensor.rand(4)
|
||||
r = (a+b).sqrt() * ((a+1).exp())
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_matmul(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = a@b
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UPCAST, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts
|
||||
[Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 1, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals
|
||||
[Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce
|
||||
# Checking all together
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4),
|
||||
Opt(OptOps.UPCAST, 1, 2)],
|
||||
# Full global upcast + local
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)],
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_double_reduce(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(8, N, 8, N)
|
||||
r = a.sum(axis=(1,3))
|
||||
helper_linearizer_opt(r, [
|
||||
# openCL / GPU=1 is 256 max threads
|
||||
[Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce.
|
||||
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces.
|
||||
[Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts.
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)],
|
||||
# Checking how it works with 2 grouped_reduces + upcasts + locals.
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
|
||||
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals.
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
|
||||
Opt(OptOps.UPCAST, 0, 2)], # No globals
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
|
||||
"test requires tensor cores with accumulation in half") # testing with half suffices.
|
||||
def test_tensor_core_opts(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
|
||||
r = a.matmul(b, dtype=dtypes.half)
|
||||
atol, rtol = 0.25, 0.01
|
||||
helper_linearizer_opt(r, [
|
||||
[],
|
||||
[Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
||||
[Opt(OptOps.UNROLL, 0, 2)], # check unroll
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
|
||||
"test requires tensor cores with accumulation in half") # testing with half suffices.
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_tensor_core_opts_locals(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
|
||||
r = a.matmul(b, dtype=dtypes.half)
|
||||
atol, rtol = 0.25, 0.01
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
|
||||
[Opt(OptOps.LOCAL, 0, 4)], # check local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
@unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
|
||||
"test requires tensor cores with accumulation in half") # testing with half suffices.
|
||||
# NOTE: the METAL test is broken, likely due to a compiler bug. passes on CI with -O0 and with default opt level locally on M3
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL", "broken for METAL")
|
||||
@unittest.skip("feature was removed")
|
||||
def test_tensor_core_opts_group(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
|
||||
r = a.matmul(b, dtype=dtypes.half)
|
||||
atol, rtol = 0.25, 0.01
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 2)],
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
def test_padto_matmul(self):
|
||||
if (CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]):
|
||||
self.skipTest("super slow on CUDA and AMD because of the big grid dims")
|
||||
N = 17 * 17
|
||||
Tensor.manual_seed(289)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
helper_linearizer_opt(a@b, [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 1, 32)],
|
||||
[Opt(OptOps.PADTO, 2, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
|
||||
# can optimize further post PADTO
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
|
||||
])
|
||||
|
||||
def test_padto_upcasted_not_ok(self):
|
||||
N = 4
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
helper_linearizer_opt(a@b, [
|
||||
[Opt(OptOps.UPCAST, 0, 0)],
|
||||
[Opt(OptOps.UPCAST, 1, 0)],
|
||||
[Opt(OptOps.UNROLL, 0, 0)],
|
||||
[Opt(OptOps.PADTO, 0, 8)],
|
||||
[Opt(OptOps.PADTO, 1, 8)],
|
||||
[Opt(OptOps.PADTO, 2, 8)],
|
||||
])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
|
||||
|
||||
def test_padto_sum_ok(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
|
||||
a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100
|
||||
b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17)))
|
||||
|
||||
helper_linearizer_opt(a.sum(0), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
helper_linearizer_opt(a.sum(1), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
# can pad sum reduce axis if there's no unsafe ops prior to sum
|
||||
for axis in (0, 1):
|
||||
helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
# TODO: why?
|
||||
if Device.DEFAULT != "WEBGPU":
|
||||
helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
|
||||
# having unsafe ops after sum is fine
|
||||
helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_sum_not_ok(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
|
||||
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
|
||||
# exp is not safe to pad
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
b = a < 1
|
||||
# lt is not safe to pad
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_max(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one axis
|
||||
a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
|
||||
|
||||
helper_linearizer_opt(a.max(0), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
helper_linearizer_opt(a.max(1), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
# cannot pad max kernel on reduce
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_where(self):
|
||||
Tensor.manual_seed(0)
|
||||
N = 17 * 17
|
||||
a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0)
|
||||
helper_linearizer_opt(a.max(0), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
def test_padto_where_multioutput(self):
|
||||
Tensor.manual_seed(0)
|
||||
N = 17 * 17
|
||||
r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1
|
||||
a0 = r.where(1, 0)
|
||||
a1 = r.where(2, 0)
|
||||
helper_linearizer_opt([a0.max(0), a1.max(0)], [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_color_shapes_with_local(self):
|
||||
N = 32
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = a@b
|
||||
opts_shapes = [
|
||||
([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
|
||||
# check to ensure local_dims are stable for full UNROLL of the first reduce
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
# check behavior for full UNROLL on an existing GROUP
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
|
||||
]
|
||||
helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -723,230 +723,6 @@ class TestLinearizer(unittest.TestCase):
|
||||
out = [u for u in get_program(k.ast, k.opts, k.applied_opts).uops if u.op is Ops.STORE][0]
|
||||
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
|
||||
class TestFloat4(unittest.TestCase):
|
||||
@staticmethod
|
||||
def count_float4(uops: list[UOp], n=4):
|
||||
return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]),
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)]))
|
||||
@staticmethod
|
||||
def count_half4(uops: list[UOp]):
|
||||
return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)]))
|
||||
|
||||
def test_float4_basic(self):
|
||||
a = Tensor.empty(2, 8).realize()
|
||||
b = Tensor.empty(2, 8).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
realized_ast = s.ast
|
||||
opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
|
||||
|
||||
assert TestFloat4.count_float4(program.uops) == (2, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim(self):
|
||||
a = Tensor.empty(2, 8).realize()
|
||||
b = Tensor.empty(2, 8).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
|
||||
assert TestFloat4.count_float4(uops) == (4, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_amx(self):
|
||||
def kernel_for_shape(size, shift):
|
||||
a = Tensor.empty(2, size).realize()
|
||||
b = Tensor.empty(2, size).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops
|
||||
|
||||
sizes = [12, 8, 16]
|
||||
shifts = [3, 2, 4]
|
||||
expected_upcast_size = [4, 8, 16]
|
||||
expected_output = [(6,3), (2,1), (2,1)]
|
||||
|
||||
for i in range(len(sizes)):
|
||||
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
|
||||
|
||||
def test_float4_unaligned_load(self):
|
||||
a = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
b = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
realized_ast = s.ast
|
||||
opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
|
||||
|
||||
assert TestFloat4.count_float4(program.uops) == (0, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_unaligned_load(self):
|
||||
a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_unaligned_load_amx(self):
|
||||
def kernel_for_shape(size, shift):
|
||||
a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
|
||||
b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops
|
||||
|
||||
sizes = [13, 9, 17]
|
||||
shifts = [3, 2, 4]
|
||||
expected_upcast_size = [4, 8, 16]
|
||||
expected_output = [(0,3), (0,1), (0,1)]
|
||||
|
||||
for i in range(len(sizes)):
|
||||
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
|
||||
|
||||
def test_float4_sometimes_unaligned(self):
|
||||
a = Tensor.empty(1, 1, 8).realize()
|
||||
b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
|
||||
c = a.conv2d(b)
|
||||
# only the first and last conv dot products are aligned in a, and b is never aligned, so no
|
||||
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 0)
|
||||
|
||||
def test_float4_multidim_sometimes_unaligned(self):
|
||||
a = Tensor.empty(1, 1, 7).realize()
|
||||
b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
|
||||
c = a.conv2d(b)
|
||||
# the first conv dot product is aligned in a. If we upcast the output and reduce
|
||||
# dimension, then we could do float4 for only that one set of loads, but we currently
|
||||
# don't.
|
||||
# UPDATE: now we do this fusion
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}
|
||||
|
||||
def test_float4_expand(self):
|
||||
a = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
b = Tensor.empty(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,))
|
||||
c = a + b
|
||||
|
||||
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 1)
|
||||
|
||||
def test_float4_heterogeneous(self):
|
||||
a = Tensor.empty(8).realize()
|
||||
b = Tensor.empty(9).realize().shrink(((1, 9),))
|
||||
c = a + b
|
||||
|
||||
# should float4 b but not a
|
||||
|
||||
s = c.schedule()[0]
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (1, 1)
|
||||
|
||||
def test_half4_load_unrolled(self):
|
||||
# from llama 7B shard 4 gpus
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(96000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(96000), arg=0, src=()),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(9216), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(9216), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(32768000), arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(32768000), arg=2, src=()),)),)),)),)),)),)),))
|
||||
|
||||
# TODO: fix this, expected might change but should be positive
|
||||
for expected, opts in [
|
||||
((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
|
||||
((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
|
||||
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
|
||||
]:
|
||||
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer)
|
||||
|
||||
count = TestFloat4.count_half4(program.uops)
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.skip("this doesn't happen anymore")
|
||||
def test_float4_acc(self):
|
||||
# from float32 stable diffusion red tinybox
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(33554432), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(33554432), arg=0, src=()),)),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(67108864), arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(67108864), arg=1, src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(294912), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(294912), arg=2, src=()),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float.ptr(128), arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(128), arg=3, src=()),)),)),)),)),))
|
||||
|
||||
for expected, opts in [
|
||||
(1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]),
|
||||
(4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]),
|
||||
]:
|
||||
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer)
|
||||
count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(4)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.skip("this doesn't happen anymore")
|
||||
def test_float2_acc(self):
|
||||
# from resnet
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(212926464), arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(212926464), arg=0, src=()),)),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.half.ptr(462422016), arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(462422016), arg=1, src=()),)),)),)),)),)),)),))
|
||||
for expected, opts in [
|
||||
(16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501
|
||||
(4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
|
||||
]:
|
||||
ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
|
||||
program = get_program(ast, Device[Device.DEFAULT].renderer)
|
||||
count = len([uop for uop in program.uops if uop.op is Ops.DEFINE_REG and uop.dtype == dtypes.float.vec(2)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
class TestHandCodedOpts(unittest.TestCase):
|
||||
def test_masked_upcast(self):
|
||||
layer_1 = Tensor.cat(*[Tensor.empty(5) for _ in range(4)])
|
||||
@@ -1088,321 +864,5 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
|
||||
check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
|
||||
return lins
|
||||
|
||||
class TestKernelOpts(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_local_and_grouped_reduce(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1882)
|
||||
a = Tensor.rand(4, 4, N, N)
|
||||
b = Tensor.rand(4, 4, N)
|
||||
r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 8)],
|
||||
[Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals
|
||||
[Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)],
|
||||
[Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
# Checking how it works with locals + grouped reduce
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)],
|
||||
# Checking how it works with locals + grouped reduce + upcasts
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)],
|
||||
# many local + many group
|
||||
[Opt(OptOps.GROUP, 0, 2)] * 4,
|
||||
[Opt(OptOps.LOCAL, 0, 2)] * 4,
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4,
|
||||
])
|
||||
|
||||
def test_upcasts(self):
|
||||
N = 16
|
||||
Tensor.manual_seed(1772)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = (a+b).sqrt() * ((a+1).exp())
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UPCAST, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts
|
||||
])
|
||||
|
||||
def test_full_upcast(self):
|
||||
Tensor.manual_seed(1772)
|
||||
a = Tensor.rand(4)
|
||||
b = Tensor.rand(4)
|
||||
r = (a+b).sqrt() * ((a+1).exp())
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_matmul(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = a@b
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UPCAST, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts
|
||||
[Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 1, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals
|
||||
[Opt(OptOps.GROUPTOP, 0, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce
|
||||
# Checking all together
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4),
|
||||
Opt(OptOps.UPCAST, 1, 2)],
|
||||
# Full global upcast + local
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)],
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_double_reduce(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(8, N, 8, N)
|
||||
r = a.sum(axis=(1,3))
|
||||
helper_linearizer_opt(r, [
|
||||
# openCL / GPU=1 is 256 max threads
|
||||
[Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
|
||||
[Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce.
|
||||
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces.
|
||||
[Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts.
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)],
|
||||
# Checking how it works with 2 grouped_reduces + upcasts + locals.
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
|
||||
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals.
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
|
||||
Opt(OptOps.UPCAST, 0, 2)], # No globals
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
|
||||
"test requires tensor cores with accumulation in half") # testing with half suffices.
|
||||
def test_tensor_core_opts(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
|
||||
r = a.matmul(b, dtype=dtypes.half)
|
||||
atol, rtol = 0.25, 0.01
|
||||
helper_linearizer_opt(r, [
|
||||
[],
|
||||
[Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
|
||||
[Opt(OptOps.UNROLL, 0, 2)], # check unroll
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
|
||||
[Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
|
||||
"test requires tensor cores with accumulation in half") # testing with half suffices.
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_tensor_core_opts_locals(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
|
||||
r = a.matmul(b, dtype=dtypes.half)
|
||||
atol, rtol = 0.25, 0.01
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
|
||||
[Opt(OptOps.LOCAL, 0, 4)], # check local
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
@unittest.skipUnless(any(tc.dtype_in == tc.dtype_out == dtypes.half for tc in Device[Device.DEFAULT].renderer.tensor_cores),
|
||||
"test requires tensor cores with accumulation in half") # testing with half suffices.
|
||||
# NOTE: the METAL test is broken, likely due to a compiler bug. passes on CI with -O0 and with default opt level locally on M3
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL", "broken for METAL")
|
||||
@unittest.skip("feature was removed")
|
||||
def test_tensor_core_opts_group(self):
|
||||
N = 128
|
||||
Tensor.manual_seed(1552)
|
||||
a, b = Tensor.rand(N, N, dtype=dtypes.half), Tensor.rand(N, N, dtype=dtypes.half)
|
||||
r = a.matmul(b, dtype=dtypes.half)
|
||||
atol, rtol = 0.25, 0.01
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.GROUPTOP, 0, 4)],
|
||||
[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)],
|
||||
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 2)],
|
||||
], apply_tc=True, atol=atol, rtol=rtol)
|
||||
|
||||
def test_padto_matmul(self):
|
||||
if (CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]):
|
||||
self.skipTest("super slow on CUDA and AMD because of the big grid dims")
|
||||
N = 17 * 17
|
||||
Tensor.manual_seed(289)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
helper_linearizer_opt(a@b, [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 1, 32)],
|
||||
[Opt(OptOps.PADTO, 2, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
|
||||
# can optimize further post PADTO
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
|
||||
])
|
||||
|
||||
def test_padto_upcasted_not_ok(self):
|
||||
N = 4
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
helper_linearizer_opt(a@b, [
|
||||
[Opt(OptOps.UPCAST, 0, 0)],
|
||||
[Opt(OptOps.UPCAST, 1, 0)],
|
||||
[Opt(OptOps.UNROLL, 0, 0)],
|
||||
[Opt(OptOps.PADTO, 0, 8)],
|
||||
[Opt(OptOps.PADTO, 1, 8)],
|
||||
[Opt(OptOps.PADTO, 2, 8)],
|
||||
])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
|
||||
|
||||
def test_padto_sum_ok(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
|
||||
a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100
|
||||
b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17)))
|
||||
|
||||
helper_linearizer_opt(a.sum(0), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
helper_linearizer_opt(a.sum(1), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
# can pad sum reduce axis if there's no unsafe ops prior to sum
|
||||
for axis in (0, 1):
|
||||
helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
# TODO: why?
|
||||
if Device.DEFAULT != "WEBGPU":
|
||||
helper_linearizer_opt(b.sum(0, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
helper_linearizer_opt(b.sum(1, dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
|
||||
|
||||
# having unsafe ops after sum is fine
|
||||
helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_sum_not_ok(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
|
||||
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
|
||||
# exp is not safe to pad
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
b = a < 1
|
||||
# lt is not safe to pad
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_max(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one axis
|
||||
a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
|
||||
|
||||
helper_linearizer_opt(a.max(0), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
helper_linearizer_opt(a.max(1), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
# cannot pad max kernel on reduce
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
|
||||
|
||||
def test_padto_where(self):
|
||||
Tensor.manual_seed(0)
|
||||
N = 17 * 17
|
||||
a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0)
|
||||
helper_linearizer_opt(a.max(0), [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
def test_padto_where_multioutput(self):
|
||||
Tensor.manual_seed(0)
|
||||
N = 17 * 17
|
||||
r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1
|
||||
a0 = r.where(1, 0)
|
||||
a1 = r.where(2, 0)
|
||||
helper_linearizer_opt([a0.max(0), a1.max(0)], [
|
||||
[Opt(OptOps.PADTO, 0, 32)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_color_shapes_with_local(self):
|
||||
N = 32
|
||||
Tensor.manual_seed(1552)
|
||||
a = Tensor.rand(N, N)
|
||||
b = Tensor.rand(N, N)
|
||||
r = a@b
|
||||
opts_shapes = [
|
||||
([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
|
||||
# check to ensure local_dims are stable for full UNROLL of the first reduce
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
# check behavior for full UNROLL on an existing GROUP
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
|
||||
([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
|
||||
([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
|
||||
]
|
||||
helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Context, Device
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.codegen.opt.kernel import Opt, OptOps
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.uop.ops import KernelInfo
|
||||
|
||||
class TestLinearizerRewrite(unittest.TestCase):
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_ex
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
|
||||
from tinygrad.codegen.opt.kernel import pm_get_optimization, pm_do_optimize
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.codegen.opt.postrange import pm_postrange_opt
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
|
||||
@@ -1,51 +1,26 @@
|
||||
# opt opinionatedly transforms an ast into an optimized ast using either heuristics or beam search
|
||||
from __future__ import annotations
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import AxisType
|
||||
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, KernelInfo
|
||||
from tinygrad.helpers import NOOPT, BEAM, getenv, POSTOPT
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.spec import type_verify
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp|None:
|
||||
"""
|
||||
Optimize an AST based on heuristics or BEAM search.
|
||||
@dataclass(frozen=True, order=True)
|
||||
class Opt:
|
||||
op: OptOps
|
||||
axis: int|None = None
|
||||
arg: int|tuple|None = None
|
||||
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
|
||||
|
||||
Args:
|
||||
ast: The Ops.SINK rooted AST
|
||||
renderer: The renderer used to generate the code
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow",
|
||||
AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
|
||||
Returns:
|
||||
The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
|
||||
"""
|
||||
|
||||
# no shape, no opt
|
||||
if ast.src[0].st is None: return None
|
||||
new_arg = ast.arg
|
||||
if new_arg is None:
|
||||
k = Kernel(ast, opts=renderer)
|
||||
if not NOOPT:
|
||||
k.apply_opts(hand_coded_optimizations(k))
|
||||
if not POSTOPT and BEAM >= 1:
|
||||
from tinygrad.codegen.opt.search import beam_search, bufs_from_lin
|
||||
kb = Kernel(ast, opts=renderer)
|
||||
rawbufs = bufs_from_lin(kb, allocate=False)
|
||||
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
new_arg = KernelInfo(opts_to_apply=tuple(k.applied_opts))
|
||||
elif len(new_arg.applied_opts): return None
|
||||
return Kernel(ast.replace(arg=None), opts=renderer).get_optimized_ast().replace(arg=new_arg)
|
||||
|
||||
pm_get_optimization = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), lambda ctx,ast: get_optimized_ast(ast, ctx)),
|
||||
])
|
||||
|
||||
def apply_opt(ast:UOp, renderer:Renderer):
|
||||
k = Kernel(ast, opts=renderer)
|
||||
k.apply_opts(ast.arg.opts_to_apply)
|
||||
ret = k.get_optimized_ast()
|
||||
if __debug__: type_verify(list(ret.toposort()))
|
||||
return ret
|
||||
|
||||
pm_do_optimize = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), lambda ctx,ast: apply_opt(ast, ctx) if ast.arg is not None and ast.arg.opts_to_apply is not None else None),
|
||||
])
|
||||
class KernelOptError(Exception): pass
|
||||
def check(cond:bool, msg:str=""):
|
||||
if not cond: raise KernelOptError(msg)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import itertools
|
||||
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError, AxisType
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.uop.ops import Ops, resolve
|
||||
from tinygrad.uop.ops import Ops, resolve, AxisType
|
||||
|
||||
# both versions
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
|
||||
def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
# first try the tensor cores
|
||||
|
||||
@@ -3,40 +3,18 @@ import itertools, functools, math
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final, Callable, Sequence
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType
|
||||
from tinygrad.codegen.opt import OptOps, Opt, KernelOptError, check, axis_letters, axis_colors
|
||||
from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType, PatternMatcher, UPat
|
||||
from tinygrad.uop.spec import type_verify, ast_spec
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, NOOPT, BEAM, getenv, POSTOPT
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_left_through_load
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class Opt:
|
||||
op: OptOps
|
||||
axis: int|None = None
|
||||
arg: int|tuple|None = None
|
||||
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
|
||||
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.LOCAL: "l", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.LOOP: "WHITE", AxisType.UPCAST: "yellow",
|
||||
AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
|
||||
class KernelOptError(Exception): pass
|
||||
def check(cond:bool, msg:str=""):
|
||||
if not cond: raise KernelOptError(msg)
|
||||
|
||||
@dataclass
|
||||
class TensorCoreOptions:
|
||||
axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
|
||||
@@ -455,3 +433,47 @@ class Kernel:
|
||||
fixed_ast = fixup_ast(self.ast)
|
||||
del fixup_ast
|
||||
return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")
|
||||
|
||||
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp|None:
|
||||
"""
|
||||
Optimize an AST based on heuristics or BEAM search.
|
||||
|
||||
Args:
|
||||
ast: The Ops.SINK rooted AST
|
||||
renderer: The renderer used to generate the code
|
||||
|
||||
Returns:
|
||||
The Ops.SINK rooted AST transformed to apply the opts and with a KernelInfo in the arg.
|
||||
"""
|
||||
|
||||
# no shape, no opt
|
||||
if ast.src[0].st is None: return None
|
||||
new_arg = ast.arg
|
||||
if new_arg is None:
|
||||
k = Kernel(ast, opts=renderer)
|
||||
if not NOOPT:
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
k.apply_opts(hand_coded_optimizations(k))
|
||||
if not POSTOPT and BEAM >= 1:
|
||||
from tinygrad.codegen.opt.search import beam_search, bufs_from_lin
|
||||
kb = Kernel(ast, opts=renderer)
|
||||
rawbufs = bufs_from_lin(kb, allocate=False)
|
||||
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
new_arg = KernelInfo(opts_to_apply=tuple(k.applied_opts))
|
||||
elif len(new_arg.applied_opts): return None
|
||||
return Kernel(ast.replace(arg=None), opts=renderer).get_optimized_ast().replace(arg=new_arg)
|
||||
|
||||
pm_get_optimization = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), lambda ctx,ast: get_optimized_ast(ast, ctx)),
|
||||
])
|
||||
|
||||
def apply_opt(ast:UOp, renderer:Renderer):
|
||||
k = Kernel(ast, opts=renderer)
|
||||
k.apply_opts(ast.arg.opts_to_apply)
|
||||
ret = k.get_optimized_ast()
|
||||
if __debug__: type_verify(list(ret.toposort()))
|
||||
return ret
|
||||
|
||||
pm_do_optimize = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), lambda ctx,ast: apply_opt(ast, ctx) if ast.arg is not None and ast.arg.opts_to_apply is not None else None),
|
||||
])
|
||||
@@ -6,7 +6,7 @@ from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import AddrSpace, dtypes
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up
|
||||
from tinygrad.codegen.opt.kernel import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
|
||||
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.schedule.rangeify import remove_tags
|
||||
|
||||
|
||||
@@ -7,12 +7,15 @@ from tinygrad.device import Device, Buffer, Compiler
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
|
||||
from tinygrad.helpers import IGNORE_BEAM_CACHE
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
|
||||
# both versions
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
|
||||
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
|
||||
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||
from tinygrad.codegen.opt.kernel import Opt
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
|
||||
Reference in New Issue
Block a user