fix tensor cores for gfx1201 (#9838)

* fix tensor cores for gfx1201

* fix typo

* fix python wmma

* AMDLLVMRenderer with arch + AMDLLVM tensor_cores

* fix ci

* clean up

* more tensor cores for RDNA4

* fix half/half, bfloat16/float, bfloat16/bfloat16 for amd_llvm

---------

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
b1tg authored on 2025-04-25 02:57:41 +08:00, committed by GitHub
parent 779aa1e2e9, commit 914d89fa0b
4 changed files with 44 additions and 11 deletions

@@ -250,6 +250,14 @@ jobs:
PYTHONPATH=. DEBUG=2 EMULATE_AMD_MFMA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD_MFMA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores
PYTHONPATH=. DEBUG=2 EMULATE_AMD_MFMA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
- name: Test emulated AMD RDNA4 tensor cores
run: |
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
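For context, the new RDNA4 step drives the same matrix as the MFMA step above: N sets the square matmul size, HALF selects fp16 inputs, and ACC_HALF toggles the accumulator dtype. A minimal stand-in for one cell of that matrix (a sketch, not the actual extra/gemm/simple_matmul.py; the acc_dtype keyword on matmul is an assumption about the Tensor API):

# Sketch: run an N x N half matmul through the emulated RDNA4 tensor cores.
# EMULATE_AMD_RDNA4 must be set before the PYTHON device is created, since
# PythonRenderer reads it in __init__ (see the last hunk below).
import os
os.environ.setdefault("EMULATE_AMD_RDNA4", "1")
os.environ.setdefault("PYTHON", "1")

from tinygrad import Tensor, dtypes

N = int(os.getenv("N", "16"))
acc_dtype = dtypes.half if os.getenv("ACC_HALF") == "1" else dtypes.float  # assumed kwarg
a = Tensor.rand(N, N, dtype=dtypes.half)
b = Tensor.rand(N, N, dtype=dtypes.half)
print(a.matmul(b, acc_dtype=acc_dtype).numpy())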

@@ -411,16 +411,20 @@ class AMDRenderer(CStyleLanguage):
tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
tensor_cores_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","u1","u1","u1","l1"), swizzle=(((9,10,11,4,7),(0,1,2,3,5,6,8)),((0,1,2,3,7),(4,9,10,11,5,6,8))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
# https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
tensor_cores_mfma = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","u1","u1","l1","l1"), swizzle=(((10,11,4,5,8,9),(0,1,2,3,6,7)),((0,1,2,3,8,9),(4,5,10,11,6,7))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
@staticmethod
def get_tensor_cores(arch):
return {"gfx942": AMDRenderer.tensor_cores_mfma, "gfx1201": AMDRenderer.tensor_cores_rdna4}.get(arch.split(":")[0], AMDRenderer.tensor_cores)
def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700
self.arch = arch
# TODO: fix tensor cores for gfx1201
self.tensor_cores = \
AMDRenderer.tensor_cores_mfma if arch.split(":")[0] == "gfx942" else AMDRenderer.tensor_cores if arch.split(":")[0] != "gfx1201" else []
self.tensor_cores = self.get_tensor_cores(arch)
if self.arch.split(":")[0] == "gfx942":
self.string_rewrite = PatternMatcher([
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)")]) + base_rewrite
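The new staticmethod replaces the inline ternary (and its TODO): the arch string, which may carry feature suffixes such as "gfx942:sramecc+", is split on ":" and the base target selects the tensor-core list, falling back to the RDNA3 list for everything else. A quick check of the dispatch (sketch, assuming a tinygrad checkout; the dtype_in/dtype_out field names follow the TensorCore constructor above):

# gfx942 -> MFMA list, gfx1201 -> RDNA4 WMMA list, anything else -> RDNA3 list.
from tinygrad.renderer.cstyle import AMDRenderer

for arch in ["gfx942:sramecc+:xnack-", "gfx1100", "gfx1201"]:
    tcs = AMDRenderer.get_tensor_cores(arch)
    print(arch, [(tc.dtype_in, tc.dtype_out) for tc in tcs])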
@@ -473,7 +477,7 @@ class AMDRenderer(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16" }
used_dtypes = uops_to_dtypes(uops)
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
@@ -481,6 +485,9 @@ class AMDRenderer(CStyleLanguage):
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if self.arch.split(":")[0] == "gfx942":
prefix.append(f"#define __{arg[0]} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if arg[2] == dtypes.half else 'bf16_1k'}")
# #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
elif self.arch.split(":")[0] == "gfx1201":
prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_{type_map[arg[3]]}_16x16x16_{type_map[arg[2]]}_w32_gfx12")
elif arg[3] == dtypes.float:
prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if arg[2] == dtypes.half else 'bf16'}_w32")
else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
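The hunk is truncated above, but the gfx1201 branch is mechanical: it derives the _gfx12 builtin name from the WMMA arg, with arg[2] the input dtype and arg[3] the output dtype. A sketch of the expansion (dtype names stand in for the dtypes objects), which reproduces the commented half/half example above:

# Same f-string as the gfx1201 branch, with type_map from render_kernel.
type_map = {"bfloat16": "bf16", "float": "f32", "half": "f16"}
def gfx1201_define(name: str, dtype_in: str, dtype_out: str) -> str:
    return f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12"

print(gfx1201_define("WMMA_16_16_16_half_half", "half", "half"))
# -> #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12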

@@ -45,7 +45,7 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
def render_wmma_amd(ctx, wmma: UOp) -> str:
dt_map = {dtypes.half: "f16", dtypes.float: "f32"}
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.bfloat16: "bf16", dtypes.ushort: "bf16"}
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
@@ -218,7 +218,6 @@ class AMDLLVMRenderer(LLVMRenderer):
has_shared = True
shared_max = AMDRenderer.shared_max
global_max = AMDRenderer.global_max
tensor_cores = AMDRenderer.tensor_cores
abi = "amdgpu_kernel"
string_rewrite = PatternMatcher([
(UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
@@ -229,9 +228,22 @@ class AMDLLVMRenderer(LLVMRenderer):
f" {ctx[x]}= shufflevector <16 x half> {ctx[y]}, <16 x half> undef, <8 x i32> <{', '.join([f'i32 {x}' for x in range(0, 16, 2)])}>"),
(UPat(Ops.WMMA, name="wmma"), render_wmma_amd),
]) + base_rewrite
extra_matcher = PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8)))
]) + LLVMRenderer.extra_matcher
def __init__(self, arch:str): self.arch = arch
extra_matcher = LLVMRenderer.extra_matcher
def __init__(self, arch:str):
self.arch = arch
self.tensor_cores = AMDRenderer.get_tensor_cores(arch)
if self.arch.split(":")[0] == "gfx1100":
self.extra_matcher += PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8)))
])
if self.arch.split(":")[0] == "gfx1201":
self.extra_matcher += PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.uint16.vec(8),
(x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)), x.src[2].bitcast(dtypes.uint16.vec(8))), (*x.arg,))
.bitcast(dtypes.bfloat16.vec(8)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(8)),
lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
])
def __reduce__(self): return self.__class__, (self.arch,)
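The two gfx1201 patterns exist because the RDNA4 bf16 WMMA intrinsics take i16 vectors rather than bfloat vectors in LLVM IR, so bf16 operands (and, in the bf16-accumulate case, results) are routed through uint16 bitcasts. The reinterpretation is lossless, since bf16 is exactly the top 16 bits of an f32; a stdlib illustration:

# bf16 <-> u16 bitcasting round-trips exactly: a bf16 value is the high 16
# bits of the corresponding f32 bit pattern.
import struct

def f32_to_bf16_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def bf16_bits_to_f32(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

x = 3.140625  # exact in bf16: needs 7 mantissa bits, bf16 has exactly 7
assert bf16_bits_to_f32(f32_to_bf16_bits(x)) == x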

@@ -141,6 +141,11 @@ class PythonProgram:
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
elif arg[4] == "AMD" and len(inp[0]) == 8: # RDNA4
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
elif arg[4] == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, k, row, goff):
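For the new RDNA4 path, a_elem folds two offset tables into a per-k (element, lane-half) pair; enumerating it makes the operand layout visible. Elements 0-3 of each lane cover k 0-7 (split across the two 16-lane halves of the wave) and elements 4-7 cover k 8-15:

# Enumerate the A-operand layout encoded by the RDNA4 a_elem above.
for k in range(16):
    elem = k - [0, 4, 4, 8][k // 4]     # which of the 8 per-lane elements
    lane_base = [0, 16, 0, 16][k // 4]  # which 16-lane half supplies it
    print(f"k={k:2d} -> element {elem}, lanes {lane_base}..{lane_base + 15}")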
@@ -197,6 +202,7 @@ class PythonRenderer(Renderer):
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_mfma
if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_rdna4
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores