diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 69b645256c..9be32fd2a7 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -44,8 +44,11 @@ def render_wmma_amx(ctx, wmma: UOp) -> str: f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}']) -def render_wmma_amd(ctx, wmma: UOp) -> str: +def render_wmma_amd(ctx, wmma: UOp, arch: str) -> str: dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.bfloat16: "bf16", dtypes.ushort: "bf16"} + # https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl + if arch.split(":")[0] == "gfx942": return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \ + f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)" # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101) return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \ @@ -216,7 +219,6 @@ class AMDLLVMRenderer(LLVMRenderer): f"<8 x half> {ctx[y]}, <8 x half> zeroinitializer, <16 x i32> <{', '.join([f'i32 {i}, i32 {j}' for i, j in zip(range(0, 8), range(8, 16))])}>"), (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))), lambda ctx, x, y: f" {ctx[x]}= shufflevector <16 x half> {ctx[y]}, <16 x half> undef, <8 x i32> <{', '.join([f'i32 {x}' for x in range(0, 16, 2)])}>"), - (UPat(Ops.WMMA, name="wmma"), render_wmma_amd), ]) + base_rewrite extra_matcher = LLVMRenderer.extra_matcher def _render_footer(self, uops: list[UOp]) -> str: @@ -228,6 +230,7 @@ class AMDLLVMRenderer(LLVMRenderer): def __init__(self, arch:str): self.arch = arch self.tensor_cores = AMDRenderer.get_tensor_cores(arch) + self.string_rewrite += PatternMatcher([(UPat(Ops.WMMA, name="wmma"), lambda ctx, wmma, arch=arch: render_wmma_amd(ctx, wmma, arch))]) if self.arch.split(":")[0] == "gfx1100": self.extra_matcher += PatternMatcher([ (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),