mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix llvm tc on mi350 (#11404)
This commit is contained in:
@@ -48,8 +48,9 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
|
||||
def render_wmma_amd(ctx, wmma: UOp, arch: str) -> str:
|
||||
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.bfloat16: "bf16", dtypes.ushort: "bf16"}
|
||||
# https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
|
||||
if arch.split(":")[0] == "gfx942": return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
|
||||
f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
if arch.split(":")[0] in {"gfx942", "gfx950"}:
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
|
||||
f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
|
||||
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
|
||||
|
||||
Reference in New Issue
Block a user