# Patch summary (git unified diff over four tinygrad files; in this copy the per-line breaks and
# indentation inside the hunks are collapsed, so the text below is kept exactly as found).
#
# - tinygrad/codegen/opt/tc.py: adds amd_cdna_161632, a list of TensorCore entries with dims=(16,16,32),
#   threads=64, elements_per_thread=(8,8,4), built for (half -> float) and (bfloat16 -> float), and adds
#   amd_cdna4 = amd_cdna_161632 + amd_cdna.
# - tinygrad/renderer/cstyle.py: AMDRenderer.get_tensor_cores now maps "gfx950" to tc.amd_cdna4
#   ("gfx942" still maps to tc.amd_cdna; unknown archs still fall back to tc.amd_rdna3).
# - tinygrad/renderer/llvmir.py: render_wmma_amd unpacks N,M,K from wmma.arg[1]; the cdna branch emits
#   ".{N}x{M}x{K}" instead of the hard-coded ".16x16x16", and when K == 32 it overrides the dt_map
#   entries for dtypes.half / dtypes.bfloat16 with ".f16" / ".bf16". The non-cdna wmma return is unchanged.
# - tinygrad/runtime/ops_python.py: in the `device == "AMD" and threads == 64` path, a_elem indexes with
#   dims[2]//4 instead of the literal 4, and wmma_helper receives dims[2] and len(inp[0..2]) instead of
#   16, 4, 4, 4. EMULATE "AMD_MFMA" now selects tc.amd_cdna4.
#
# NOTE(review): the K == 32 dt_map.update(...) does not touch the dtypes.ushort entry, which stays
#   "bf16.1k" under cdna; a ushort source would then render as ".16x16x32bf16.1k" — confirm ushort
#   inputs cannot reach the K == 32 path.
# NOTE(review): the removed/added b_elem lines read identically in this copy; presumably the real change
#   is whitespace-only alignment of the trailing pylint comment — verify against the original patch.
diff --git a/tinygrad/codegen/opt/tc.py b/tinygrad/codegen/opt/tc.py index b5b4dedd31..c5b1d33631 100644 --- a/tinygrad/codegen/opt/tc.py +++ b/tinygrad/codegen/opt/tc.py @@ -117,6 +117,14 @@ amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), (('l0', 'l1', 'l2', 'l3', 'r2', 'r3'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1')))) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]] +amd_cdna_161632 = [TensorCore(dims=(16,16,32), threads=64, elements_per_thread=(8,8,4), dtype_in=di, dtype_out=do, + opts=("l0","l0","l0","l0","u1","u1","l1","l1"), + swizzle=((('u0','u1','l4','l5','r3','r4'), ('r0','r1'), ('l0','l1','l2','l3','r2')), + (('l0','l1','l2','l3','r3','r4'), ('r0','r1'), ('l4','l5','u0','u1','r2')))) + for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]] + +amd_cdna4 = amd_cdna_161632 + amd_cdna + # ***** Apple Metal ***** metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 450ba154d5..b00b7d0784 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -423,7 +423,7 @@ class AMDRenderer(CStyleLanguage): @staticmethod def get_tensor_cores(arch): - return {"gfx942": tc.amd_cdna, "gfx950": tc.amd_cdna, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3) + return {"gfx942": tc.amd_cdna, "gfx950": tc.amd_cdna4, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3) def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700 self.arch = arch self.tensor_cores = self.get_tensor_cores(arch) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index b67bd9cb32..165b9a9485 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -49,9 +49,11 @@ def render_wmma_amx(ctx, wmma: UOp) -> str: def render_wmma_amd(ctx, wmma: UOp, 
cdna=False) -> str: dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16"} # https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl + N,M,K = wmma.arg[1] if cdna: + if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"}) return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \ - f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)" + f".{N}x{M}x{K}{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)" # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101) return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." 
+ \ diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 780762539d..2491bb41d4 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -150,10 +150,10 @@ class PythonProgram: def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4) ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) elif device == "AMD" and threads == 64: - def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row] - def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order + def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row] + def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem) - ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map) + ul[i] = wmma_helper(64, dims[2], len(inp[0]), len(inp[1]), len(inp[2]), a_elem, b_elem, c_map) elif device == "AMD" and len(inp[0]) == 8: # RDNA4 def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]] def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) @@ -221,7 +221,7 @@ class PythonRenderer(Renderer): match cast(str, EMULATE.value): case "METAL": self.device, self.tensor_cores = "METAL", tc.metal case "AMD": self.device, self.tensor_cores = "AMD", tc.amd_rdna3 - case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna + case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna4 case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4 case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80 case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75