add AMX support to LLVM (#8957)

* init amx support for llvm

* revert elf changes

* fix attributes for AMX asm calls

* add comments

* add llvm amx job to benchmarks

* cleanup

* cleanup

* hotfix: improve comments

* comment for aux buffers

* hotfix:

* move amx_tc to ClangRenderer

* merge master

* refactor

* add docs

* add corsix docs reference

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Ignacio Sica
Date: 2025-02-12 05:01:18 -03:00 (committed by GitHub)
Parent: 0c97c10814
Commit: aaed315fee
5 changed files with 40 additions and 15 deletions
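
The hunks below extend the CLANG-only skip conditions to also cover the LLVM backend, since both CPU backends now share the AMX tensor-core behavior (upcasting floats up to size 16). As a rough sketch of exercising the new path, assuming an Apple Silicon host and that `LLVM=1`/`AMX=1` are the environment toggles these tests rely on (an assumption, not confirmed beyond this diff):

```python
# Minimal sketch, not from this commit: force the LLVM backend with AMX enabled.
# Env vars are set before importing tinygrad, since they are read at import time.
import os
os.environ["LLVM"] = "1"  # assumption: selects the LLVM CPU backend (Device.DEFAULT == "LLVM")
os.environ["AMX"] = "1"   # assumption: the AMX flag the tests below reference

from tinygrad import Tensor

a, b = Tensor.rand(128, 128), Tensor.rand(128, 128)
print((a @ b).numpy().sum())  # the matmul should now compile through the AMX tensor-core path
```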

@@ -1136,7 +1136,7 @@ class TestLinearizer(unittest.TestCase):
     assert u.src[-1].src[0].op != Ops.ASSIGN
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
   def test_tensor_cores_unroll_casted_phi(self):
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
     x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
@@ -1148,7 +1148,7 @@ class TestLinearizer(unittest.TestCase):
     assert u.src[-1].src[0].op != Ops.ASSIGN
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
   def test_tensor_cores_unroll_casted_phi_with_children(self):
     # all ASSIGN children are outside the loop
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
@@ -1429,7 +1429,7 @@ class TestFloat4(unittest.TestCase):
     assert TestFloat4.count_float4(k) == (2, 1)
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim(self):
     a = Tensor.rand(2, 8).realize()
     b = Tensor.rand(2, 8).realize()
@@ -1446,7 +1446,7 @@ class TestFloat4(unittest.TestCase):
     assert TestFloat4.count_float4(k) == (4, 2)
-  @unittest.skipUnless(Device.DEFAULT in {"CLANG"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
+  @unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim_amx(self):
     def kernel_for_shape(size, shift):
       a = Tensor.rand(2, size).realize()
@@ -1471,7 +1471,7 @@ class TestFloat4(unittest.TestCase):
     for i in range(len(sizes)):
       assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
   def test_float4_unaligned_load(self):
     a = Tensor.rand(9).realize().shrink(((1, 9),))
     b = Tensor.rand(9).realize().shrink(((1, 9),))
@@ -1484,7 +1484,7 @@ class TestFloat4(unittest.TestCase):
     assert TestFloat4.count_float4(k) == (0, 1)
-  @unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim_unaligned_load(self):
     a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
     b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
@@ -1501,7 +1501,7 @@ class TestFloat4(unittest.TestCase):
     assert TestFloat4.count_float4(k) == (0, 2)
-  @unittest.skipUnless(Device.DEFAULT in {"CLANG"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
+  @unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
   def test_float4_multidim_unaligned_load_amx(self):
     def kernel_for_shape(size, shift):
       a = Tensor.rand(2, size).realize().shrink(((0, 2), (1, size),))
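
For reference, the gating pattern these hunks converge on can be written standalone like this (a minimal sketch; the `getenv`-based `AMX` stand-in approximates the flag the test file actually imports):

```python
import unittest
from tinygrad import Device
from tinygrad.helpers import getenv

AMX = getenv("AMX", 0)  # stand-in for the AMX flag imported by test_linearizer.py

class TestAMXGating(unittest.TestCase):
  # after this commit, CLANG and LLVM are gated together wherever AMX changes upcasting
  @unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG/LLVM with AMX upcast float up to size 16")
  def test_amx_placeholder(self):
    self.assertTrue(Device[Device.DEFAULT].renderer.tensor_cores)  # AMX exposes tensor cores on CPU
```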