mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
add AMX support to LLVM (#8957)

* init amx support for llvm
* revert elf changes
* fix attributes for AMX asm calls
* add comments
* add llvm amx job to benchmarks
* cleanup
* cleanup
* hotfix: improve comments
* comment for aux buffers
* hotfix:
* move amx_tc to ClangRenderer
* merge master
* refactor
* add docs
* add corsix docs reference

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>

.github/workflows/benchmark.yml (vendored): 4 changed lines
@@ -63,7 +63,9 @@ jobs:
     - name: Test tensor cores
       run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
     - name: Test AMX tensor cores
-      run: DEBUG=2 CLANG=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+      run: |
+        DEBUG=2 CLANG=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+        DEBUG=2 LLVM=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
     - name: Run Tensor Core GEMM (float)
       run: DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
     - name: Run Tensor Core GEMM (half)

test/test_linearizer.py

@@ -1136,7 +1136,7 @@ class TestLinearizer(unittest.TestCase):
       assert u.src[-1].src[0].op != Ops.ASSIGN
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
   def test_tensor_cores_unroll_casted_phi(self):
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
     x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
@@ -1148,7 +1148,7 @@ class TestLinearizer(unittest.TestCase):
       assert u.src[-1].src[0].op != Ops.ASSIGN
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
   def test_tensor_cores_unroll_casted_phi_with_children(self):
     # all ASSIGN children are outside the loop
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
@@ -1429,7 +1429,7 @@ class TestFloat4(unittest.TestCase):
 
     assert TestFloat4.count_float4(k) == (2, 1)
 
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim(self):
     a = Tensor.rand(2, 8).realize()
     b = Tensor.rand(2, 8).realize()
@@ -1446,7 +1446,7 @@ class TestFloat4(unittest.TestCase):
 
     assert TestFloat4.count_float4(k) == (4, 2)
 
-  @unittest.skipUnless(Device.DEFAULT in {"CLANG"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
+  @unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim_amx(self):
     def kernel_for_shape(size, shift):
       a = Tensor.rand(2, size).realize()
@@ -1471,7 +1471,7 @@ class TestFloat4(unittest.TestCase):
     for i in range(len(sizes)):
       assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]
 
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
   def test_float4_unaligned_load(self):
     a = Tensor.rand(9).realize().shrink(((1, 9),))
     b = Tensor.rand(9).realize().shrink(((1, 9),))
@@ -1484,7 +1484,7 @@ class TestFloat4(unittest.TestCase):
 
     assert TestFloat4.count_float4(k) == (0, 1)
 
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim_unaligned_load(self):
     a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
     b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
@@ -1501,7 +1501,7 @@ class TestFloat4(unittest.TestCase):
 
     assert TestFloat4.count_float4(k) == (0, 2)
 
-  @unittest.skipUnless(Device.DEFAULT in {"CLANG"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
+  @unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim_unaligned_load_amx(self):
     def kernel_for_shape(size, shift):
       a = Tensor.rand(2, size).realize().shrink(((0, 2), (1, size),))

@@ -339,7 +339,7 @@ class Kernel:
       if extra_opts is not None:
         for opt in extra_opts: self.apply_opt(opt)
       else:
-        if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
+        if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
         # hand-coded TC opts
         for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
           szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
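
With AMX the hand-coded upcasts are now skipped for every backend, since the AMX tensor core already covers the tile and extra upcasting makes the kernel slower (per the comment in the diff). For the non-AMX path, the loop above picks a size from [5,4,3,2] that divides the M or N axis; a rough standalone sketch of that selection (plain Python with made-up sizes, assuming the first surviving candidate is the one applied):

# Sketch of the upcast-size selection in the hand-coded TC opts; not tinygrad's Kernel API.
def pick_upcast_size(axis_size: int) -> int | None:
  szs = [sz for sz in [5, 4, 3, 2] if axis_size % sz == 0]  # same candidate list as the diff
  return szs[0] if szs else None  # assumed: the first (largest) surviving divisor is applied

assert pick_upcast_size(12) == 4    # 12 % 4 == 0, so M/N would be upcast by 4
assert pick_upcast_size(7) is None  # nothing divides 7, no hand-coded upcast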

@@ -173,6 +173,9 @@ class ClangRenderer(CStyleLanguage):
   global_max = None
   infinity = "__builtin_inff()"
   nan = '__builtin_nanf("")'
+  amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
+                       opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+  if AMX: tensor_cores = amx_tc
 
   # language options
   buffer_suffix = " restrict"
@@ -183,10 +186,6 @@ class ClangRenderer(CStyleLanguage):
   extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
                   CStyleLanguage.extra_matcher
 
-  if AMX:
-    tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
-                               swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
-                    for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
   if sys.platform == 'win32':
     kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
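
For dtypes.float the comprehension above expands to a single entry: sz = 64 // 4 = 16, so the AMX tensor core is described as a 16x16x1 tile with a 256-element accumulator. A small standalone sketch of those concrete numbers (plain Python, with a dict standing in for tinygrad's TensorCore so the values can be checked without importing the renderer):

# Sketch of what amx_tc works out to for float32 (itemsize 4); the values mirror the
# expressions in the diff, the comments on why are best-effort reading of the change.
itemsize = 4                      # dtypes.float
sz = 64 // itemsize               # 16: each AMX x/y register holds 64 bytes
amx_float_tc = {
  "dims": (sz, sz, 1),                      # (16, 16, 1): one 16x16 outer product per WMMA
  "threads": 1,                             # AMX is driven from a single CPU thread
  "elements_per_thread": (sz, sz, sz * sz), # (16, 16, 256): A column, B row, 16x16 accumulator
  "opts": ("u0",) * 4 + ("u1",) * 4,        # same opts tuple as the diff: upcast both output axes
}
print(amx_float_tc)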

@@ -1,8 +1,10 @@
 from typing import cast
 import math, struct
 from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import ClangRenderer
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+from tinygrad.helpers import prod, AMX
 
 def ldt(dt:DType):
   if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
@@ -29,6 +31,19 @@ def lcast(input_type:DType, output_type:DType):
   if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
+# https://github.com/corsix/amx
+def render_wmma(ctx, wmma: UOp) -> str:
+  def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
+
+  return "\n".join([
+    *[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
+    *[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
+    f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
+    *[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
+    f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
+
 # llvm ops, lop[<dtype>][<op>]
 unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
                  Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
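
The inline asm strings above hand-encode Apple AMX instructions, for which there are no compiler intrinsics. Following the corsix/amx notes referenced in the diff, every AMX op is the raw 32-bit word 0x201000 + (op << 5) + operand, and the odd-looking 0$1 - ((0$1>>4)*6) term converts the substituted GPR name (e.g. x25, which prints as 0x25 with the leading zero) from its hexadecimal reading back to the decimal register number. A standalone Python sketch of that encoding (helper names are illustrative, not tinygrad code):

# Sketch of the AMX instruction encoding used by render_wmma, per https://github.com/corsix/amx.
def amx_word(op: int, operand: int) -> int:
  # op selects ldx/ldy/ldz/stz/fma32/set/clr; operand is a GPR number or small immediate
  return 0x201000 + (op << 5) + operand

def gpr_from_hex_trick(reg_num: int) -> int:
  # The asm writes the operand as `0$1`: a register name like "x25" becomes "0x25", which the
  # assembler parses as hexadecimal; subtracting 6 per tens digit recovers the decimal 25.
  as_hex = int(f"0x{reg_num}", 16)
  return as_hex - ((as_hex >> 4) * 6)

assert all(gpr_from_hex_trick(n) == n for n in range(30))  # holds for x0..x29
# op 17 with immediate 0/1 is the set/clr pair render_wmma wraps around each FMA;
# ops 0/1/4/5/12 are the ldx/ldy/ldz/stz/fma32 calls visible in the hunk above.
print(hex(amx_word(17, 0)), hex(amx_word(17, 1)))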
@@ -84,6 +99,9 @@ llvm_rewrite = PatternMatcher([
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+
+  # wmma
+  (UPat(Ops.WMMA, name="wmma"), render_wmma),
 ])
 
 def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
@@ -96,6 +114,7 @@ class LLVMRenderer(Renderer):
   has_local = False
   has_shared = False
   global_max = None
+  if AMX: tensor_cores = ClangRenderer.amx_tc
 
   extra_matcher = PatternMatcher([
     # rewrite RECIP with FDIV
@@ -118,14 +137,19 @@ class LLVMRenderer(Renderer):
     end_lines: dict[str, None] = {}
     vc = -1
 
-    # prealloc all assigns
     acc_to_assign: dict[UOp, UOp] = {}
     for u in uops:
-      if u.op is Ops.ASSIGN:
+      if u.op is Ops.ASSIGN: # prealloc all assigns
        vc += 1
        r[u] = r[u.src[1]] = f"%assign{vc}"
        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
        acc_to_assign[u.src[0]] = u.src[1]
+      if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+        vc += 1
+        r[u] = f"%wmma{vc}"
+        for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
+          kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
+                     f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
 
     for u in uops:
       # hack for defining sqrt function (TODO: can we get a transcendental for this?)
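
The new WMMA branch exists because the AMX instructions emitted by render_wmma only move data between the AMX register file and memory: each WMMA's two inputs and its accumulator are spilled to stack allocas (and their addresses taken with ptrtoint) so ldx/ldy/ldz can read them and stz can write the result back. A rough sketch of the buffer sizes this produces for the float tensor core (the upcast tuples stand in for u.arg[6] and are assumed, not taken from the renderer):

# Assumed sizes for the float amx_tc (elements_per_thread = (16, 16, 256), float32 itemsize 4);
# the %wmma0_amx{i} names mirror the strings in the diff, everything else is illustrative.
from math import prod

upcasts = [((0, 16),), ((1, 16),), ((0, 16), (1, 16))]   # stand-in for u.arg[6]: A, B, accumulator
itemsize = 4
for i, upcast in enumerate(upcasts):
  n = prod(size for _, size in upcast)                   # 16, 16, 256 vector elements
  print(f"%wmma0_amx{i} = alloca <{n} x float>  ; {n * itemsize} bytes, address taken for the AMX asm")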