add AMX support to LLVM (#8957)

* init amx support for llvm

* revert elf changes

* fix attributes for AMX asm calls

* add comments

* add llvm amx job to benchmarks

* cleanup

* cleanup

* hotfix: improve comments

* comment for aux buffers

* hotfix:

* move amx_tc to ClangRenderer

* merge master

* refactor

* add docs

* add corsix docs reference

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Ignacio Sica
Date: 2025-02-12 05:01:18 -03:00 (committed by GitHub)
Parent: 0c97c10814
Commit: aaed315fee
5 changed files with 40 additions and 15 deletions


@@ -63,7 +63,9 @@ jobs:
- name: Test tensor cores
run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Test AMX tensor cores
-run: DEBUG=2 CLANG=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+run: |
+DEBUG=2 CLANG=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
+DEBUG=2 LLVM=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM (float)
run: DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
- name: Run Tensor Core GEMM (half)
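Context for the step above: the AMX path is selected purely by the AMX environment flag that tinygrad's helpers read, on top of the usual CLANG/LLVM device selection. A minimal local check, assuming tinygrad is installed on an Apple Silicon machine:

from tinygrad import Device
from tinygrad.helpers import AMX  # truthy when AMX=1 is set in the environment, as in the workflow step
print("AMX requested:", bool(AMX), "| default device:", Device.DEFAULT)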


@@ -1136,7 +1136,7 @@ class TestLinearizer(unittest.TestCase):
assert u.src[-1].src[0].op != Ops.ASSIGN
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-@unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
+@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
def test_tensor_cores_unroll_casted_phi(self):
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
@@ -1148,7 +1148,7 @@ class TestLinearizer(unittest.TestCase):
assert u.src[-1].src[0].op != Ops.ASSIGN
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-@unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
+@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
def test_tensor_cores_unroll_casted_phi_with_children(self):
# all ASSIGN children are outside the loop
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
@@ -1429,7 +1429,7 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(k) == (2, 1)
-@unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
def test_float4_multidim(self):
a = Tensor.rand(2, 8).realize()
b = Tensor.rand(2, 8).realize()
@@ -1446,7 +1446,7 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(k) == (4, 2)
-@unittest.skipUnless(Device.DEFAULT in {"CLANG"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
+@unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
def test_float4_multidim_amx(self):
def kernel_for_shape(size, shift):
a = Tensor.rand(2, size).realize()
@@ -1471,7 +1471,7 @@ class TestFloat4(unittest.TestCase):
for i in range(len(sizes)):
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]
-@unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
def test_float4_unaligned_load(self):
a = Tensor.rand(9).realize().shrink(((1, 9),))
b = Tensor.rand(9).realize().shrink(((1, 9),))
@@ -1484,7 +1484,7 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(k) == (0, 1)
-@unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
def test_float4_multidim_unaligned_load(self):
a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
@@ -1501,7 +1501,7 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(k) == (0, 2)
-@unittest.skipUnless(Device.DEFAULT in {"CLANG"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
+@unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
def test_float4_multidim_unaligned_load_amx(self):
def kernel_for_shape(size, shift):
a = Tensor.rand(2, size).realize().shrink(((0, 2), (1, size),))
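Background for the skips above: the AMX tensor core accumulates in the same dtype it consumes, so the mixed-precision filter used by the casted-phi tests matches nothing on CLANG or LLVM, and the float4 counts shift because AMX upcasts to a full 16-wide tile. A small sketch of the selection those tests perform, assuming an AMX-enabled backend:

from tinygrad import Device
tcs = Device[Device.DEFAULT].renderer.tensor_cores
mixed = [tc for tc in tcs if tc.dtype_in != tc.dtype_out]  # [] with AMX tensor cores: dtype_in == dtype_out (float32)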


@@ -339,7 +339,7 @@ class Kernel:
if extra_opts is not None:
for opt in extra_opts: self.apply_opt(opt)
else:
-if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
+if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
# hand-coded TC opts
for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
szs = [sz for sz in [5,4,3,2] if self.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
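For context, the hand-coded path after the early return searches for an extra upcast size on the M/N axes; with AMX the kernel is already upcast to the full 16x16 tile, so the extra upcast only adds work. A rough illustration of that size search with a hypothetical axis length (illustration only, not from the diff):

axis_len = 120  # hypothetical full_shape entry on a tensor-core axis
szs = [sz for sz in [5, 4, 3, 2] if axis_len % sz == 0]  # -> [5, 4, 3, 2]; the first match is applied as an upcast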


@@ -173,6 +173,9 @@ class ClangRenderer(CStyleLanguage):
global_max = None
infinity = "__builtin_inff()"
nan = '__builtin_nanf("")'
+amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
+opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+if AMX: tensor_cores = amx_tc
# language options
buffer_suffix = " restrict"
@@ -183,10 +186,6 @@ class ClangRenderer(CStyleLanguage):
extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
CStyleLanguage.extra_matcher
-if AMX:
-tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
-swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
-for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
if sys.platform == 'win32':
kernel_prefix = "__attribute__((ms_abi)) "
def render_vector_prefix(self, dt:DType) -> str:
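For reference, the comprehension above yields a single float32 tensor core, with sz derived from the 64-byte AMX register row:

from tinygrad.dtype import dtypes
sz = 64 // dtypes.float.itemsize         # 64-byte AMX row / 4-byte float = 16
dims = (sz, sz, 1)                       # one 16x16x1 outer-product step per WMMA
elements_per_thread = (sz, sz, sz * sz)  # 16 floats from X, 16 from Y, 256 accumulators in Z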


@@ -1,8 +1,10 @@
from typing import cast
import math, struct
from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+from tinygrad.helpers import prod, AMX
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
@@ -29,6 +31,19 @@ def lcast(input_type:DType, output_type:DType):
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
+# https://github.com/corsix/amx
+def render_wmma(ctx, wmma: UOp) -> str:
+def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
+return "\n".join([
+*[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
+f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
+*[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
+f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
+*[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
+f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
+f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
# llvm ops, lop[<dtype>][<op>]
unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
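A note on the inline asm in render_wmma above, following the corsix/amx writeup linked in the comment: each AMX operation assembles to the undocumented instruction word 0x201000 + (op << 5) + operand, where the operand is the number of the GPR holding a packed 64-bit argument. In the template, $1 expands to a register name such as x8; the leading 0 makes the assembler parse it as a hex literal, and the -((0$1>>4)*6) term undoes the hex-versus-decimal mismatch for x10 and above. A small sketch of the arithmetic (op numbers as used above: ldx=0, ldy=1, ldz=4, stz=5, fma32=12, set/clr=17):

def amx_word(op: int, operand: int) -> int:
  # the .word value the asm strings above emit; operand is a GPR number, or the immediate 0/1 for set/clr
  return 0x201000 + (op << 5) + operand

assert amx_word(17, 0) == 0x201220  # the "AMX set" literal in render_wmma
assert amx_word(17, 1) == 0x201221  # the "AMX clr" literal in render_wmma
print(hex(amx_word(4, 8)))          # e.g. ldz with its operand packed in x8 -> 0x201088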
@@ -84,6 +99,9 @@ llvm_rewrite = PatternMatcher([
# if
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+# wmma
+(UPat(Ops.WMMA, name="wmma"), render_wmma),
])
def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
@@ -96,6 +114,7 @@ class LLVMRenderer(Renderer):
has_local = False
has_shared = False
global_max = None
+if AMX: tensor_cores = ClangRenderer.amx_tc
extra_matcher = PatternMatcher([
# rewrite RECIP with FDIV
@@ -118,14 +137,19 @@ class LLVMRenderer(Renderer):
end_lines: dict[str, None] = {}
vc = -1
-# prealloc all assigns
acc_to_assign: dict[UOp, UOp] = {}
for u in uops:
-if u.op is Ops.ASSIGN:
+if u.op is Ops.ASSIGN: # prealloc all assigns
vc += 1
r[u] = r[u.src[1]] = f"%assign{vc}"
assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
acc_to_assign[u.src[0]] = u.src[1]
+if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+vc += 1
+r[u] = f"%wmma{vc}"
+for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
+kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
+f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
for u in uops:
# hack for defining sqrt function (TODO: can we get a transcendental for this?)
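One more note on the prealloc above: AMX can only move data through memory, so each WMMA gets three stack buffers (X input, Y input, Z accumulator) whose addresses are exposed to the asm as i64 values. The i*4<<56 | i*64 constants in render_wmma pack a Z-register row index into the high bits and a byte offset into the low bits of the value added to the accumulator buffer's address; per the corsix/amx notes, a float32 matrix fma writes its 16x16 tile into every fourth Z row. A sketch of that packing (register-layout details follow the corsix reference, not the diff itself):

ROW_BYTES = 64  # one AMX register row = 64 bytes = 16 float32 values
# one ldz/stz operand per output row of the 16x16 float32 tile: Z row 4*i in bits 56+, byte offset 64*i below
operands = [(4 * i) << 56 | (i * ROW_BYTES) for i in range(16)]  # same values as i*4<<56 | i*64 above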