mirror of https://github.com/tinygrad/tinygrad.git
fix tensor cores for gfx1201 (#9838)
* fix tensor cores for gfx1201
* fix typo
* fix python wmma
* AMDLLVMRenderer with arch + AMDLLVM tensor_cores
* fix ci
* clean up
* more tensor cores for RDNA4
* fix half/half, bfloat16/float, bfloat16/bfloat16 for amd_llvm

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
.github/workflows/test.yml (vendored) | 8 ++++++++
@@ -250,6 +250,14 @@ jobs:
         PYTHONPATH=. DEBUG=2 EMULATE_AMD_MFMA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
         PYTHONPATH=. DEBUG=2 EMULATE_AMD_MFMA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores
         PYTHONPATH=. DEBUG=2 EMULATE_AMD_MFMA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
+    - name: Test emulated AMD RDNA4 tensor cores
+      run: |
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
     - name: Test emulated CUDA tensor cores
       run: |
         DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
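For reference, the EMULATE_AMD_RDNA4 commands above run tinygrad's pure-Python emulator (selected by PYTHON=1; see the PythonRenderer hunk at the end of this diff). Below is a standalone NumPy sketch of the 16x16x16 tile semantics these GEMM runs check; reading HALF/ACC_HALF as input vs. accumulation precision is taken from the flag names, not from simple_matmul.py itself.

import numpy as np

# one 16x16x16 WMMA tile: D = A @ B + C
A = np.random.rand(16, 16).astype(np.float16)   # HALF=1: half inputs
B = np.random.rand(16, 16).astype(np.float16)
C = np.zeros((16, 16), dtype=np.float32)

D_f32acc = A.astype(np.float32) @ B.astype(np.float32) + C    # ACC_HALF=0: accumulate in float
D_f16acc = (A @ B) + C.astype(np.float16)                     # ACC_HALF=1: accumulate in half
print(abs(D_f32acc - D_f16acc.astype(np.float32)).max())      # small, precision-dependent difference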
@@ -411,16 +411,20 @@ class AMDRenderer(CStyleLanguage):
   tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
     opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
     for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
+  tensor_cores_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
+    opts=("l0","l0","l0","l0","u1","u1","u1","l1"), swizzle=(((9,10,11,4,7),(0,1,2,3,5,6,8)),((0,1,2,3,7),(4,9,10,11,5,6,8))))
+    for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
   # https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
   tensor_cores_mfma = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
     opts=("l0","l0","l0","l0","u1","u1","l1","l1"), swizzle=(((10,11,4,5,8,9),(0,1,2,3,6,7)),((0,1,2,3,8,9),(4,5,10,11,6,7))))
     for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]

+  @staticmethod
+  def get_tensor_cores(arch):
+    return {"gfx942": AMDRenderer.tensor_cores_mfma, "gfx1201": AMDRenderer.tensor_cores_rdna4}.get(arch.split(":")[0], AMDRenderer.tensor_cores)
   def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700
     self.arch = arch
-    # TODO: fix tensor cores for gfx1201
-    self.tensor_cores = \
-      AMDRenderer.tensor_cores_mfma if arch.split(":")[0] == "gfx942" else AMDRenderer.tensor_cores if arch.split(":")[0] != "gfx1201" else []
+    self.tensor_cores = self.get_tensor_cores(arch)
     if self.arch.split(":")[0] == "gfx942":
       self.string_rewrite = PatternMatcher([
         (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)")]) + base_rewrite
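The new get_tensor_cores helper keys the choice on the base architecture name, so any feature-flag suffix on the arch string is ignored. Below is a standalone sketch of that lookup, with list names standing in for the real TensorCore lists; the suffixed arch string is an illustrative assumption, and the MI300 / RX 7900 / RX 9700 mapping is from the comment on __init__.

def pick_tensor_cores(arch: str) -> str:
  # mirrors the dict lookup in get_tensor_cores above
  base = arch.split(":")[0]
  return {"gfx942": "tensor_cores_mfma", "gfx1201": "tensor_cores_rdna4"}.get(base, "tensor_cores")

assert pick_tensor_cores("gfx942") == "tensor_cores_mfma"             # MI300: MFMA path
assert pick_tensor_cores("gfx1201:sramecc+") == "tensor_cores_rdna4"  # RX 9700: new RDNA4 WMMA path
assert pick_tensor_cores("gfx1100") == "tensor_cores"                 # RX 7900: existing RDNA3 WMMA path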
@@ -473,7 +477,7 @@ class AMDRenderer(CStyleLanguage):
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
-
+    type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16" }
     used_dtypes = uops_to_dtypes(uops)
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
@@ -481,6 +485,9 @@ class AMDRenderer(CStyleLanguage):
     for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if self.arch.split(":")[0] == "gfx942":
         prefix.append(f"#define __{arg[0]} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if arg[2] == dtypes.half else 'bf16_1k'}")
+      # #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
+      elif self.arch.split(":")[0] == "gfx1201":
+        prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_{type_map[arg[3]]}_16x16x16_{type_map[arg[2]]}_w32_gfx12")
       elif arg[3] == dtypes.float:
         prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if arg[2] == dtypes.half else 'bf16'}_w32")
       else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
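The new gfx1201 branch turns each WMMA op into a #define onto the matching _gfx12 builtin, with type_map supplying the dtype suffixes. Below is a standalone sketch of that name construction (string keys stand in for tinygrad dtypes); the example output is the one already given in the comment above.

type_map = {"bfloat16": "bf16", "float": "f32", "half": "f16"}

def wmma_define(name: str, dtype_in: str, dtype_out: str) -> str:
  # gfx1201 pattern: __builtin_amdgcn_wmma_<out>_16x16x16_<in>_w32_gfx12
  return f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12"

print(wmma_define("WMMA_16_16_16_half_half", "half", "half"))
# -> #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12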
@@ -45,7 +45,7 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
     f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])

 def render_wmma_amd(ctx, wmma: UOp) -> str:
-  dt_map = {dtypes.half: "f16", dtypes.float: "f32"}
+  dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.bfloat16: "bf16", dtypes.ushort: "bf16"}
   # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
   # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
   return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
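The added dt_map entries let bfloat16 inputs (or their ushort/uint16 bitcasts produced by the gfx1201 matcher further down) select the .bf16 intrinsic names. Below is a sketch of the name assembly, assuming the suffix after 16x16x16. is the input dtype as in the example comment above.

dt_map = {"half": "f16", "float": "f32", "bfloat16": "bf16", "ushort": "bf16"}

def wmma_intrinsic(acc_scalar: str, in_scalar: str) -> str:
  # "@llvm.amdgcn.wmma.<acc>.16x16x16.<in>"
  return f"@llvm.amdgcn.wmma.{dt_map[acc_scalar]}.16x16x16.{dt_map[in_scalar]}"

print(wmma_intrinsic("float", "half"))    # @llvm.amdgcn.wmma.f32.16x16x16.f16, as in the comment above
print(wmma_intrinsic("float", "ushort"))  # @llvm.amdgcn.wmma.f32.16x16x16.bf16 for bitcast bfloat16 inputs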
@@ -218,7 +218,6 @@ class AMDLLVMRenderer(LLVMRenderer):
   has_shared = True
   shared_max = AMDRenderer.shared_max
   global_max = AMDRenderer.global_max
-  tensor_cores = AMDRenderer.tensor_cores
   abi = "amdgpu_kernel"
   string_rewrite = PatternMatcher([
     (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
@@ -229,9 +228,22 @@ class AMDLLVMRenderer(LLVMRenderer):
       f" {ctx[x]}= shufflevector <16 x half> {ctx[y]}, <16 x half> undef, <8 x i32> <{', '.join([f'i32 {x}' for x in range(0, 16, 2)])}>"),
     (UPat(Ops.WMMA, name="wmma"), render_wmma_amd),
   ]) + base_rewrite
-  extra_matcher = PatternMatcher([
-    (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
-      lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8)))
-  ]) + LLVMRenderer.extra_matcher
-  def __init__(self, arch:str): self.arch = arch
+  extra_matcher = LLVMRenderer.extra_matcher
+  def __init__(self, arch:str):
+    self.arch = arch
+    self.tensor_cores = AMDRenderer.get_tensor_cores(arch)
+    if self.arch.split(":")[0] == "gfx1100":
+      self.extra_matcher += PatternMatcher([
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
+          lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8)))
+      ])
+    if self.arch.split(":")[0] == "gfx1201":
+      self.extra_matcher += PatternMatcher([
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.uint16.vec(8),
+          (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)), x.src[2].bitcast(dtypes.uint16.vec(8))), (*x.arg,))
+          .bitcast(dtypes.bfloat16.vec(8)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None),
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(8)),
+          lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
+            x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
+      ])
   def __reduce__(self): return self.__class__, (self.arch,)
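Note that the gfx1201 patterns do not convert bfloat16 values; they only retype the vectors as 16-bit integers around the WMMA, so the intrinsic sees integer lanes, and bitcast the result back. Below is a standalone NumPy illustration of why that reinterpretation is lossless: a bfloat16 value is just the top 16 bits of the corresponding float32.

import numpy as np

def bf16_bits_from_f32(x: np.ndarray) -> np.ndarray:
  # keep the upper 16 bits of each float32 (truncation; rounding omitted for brevity)
  return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def f32_from_bf16_bits(b: np.ndarray) -> np.ndarray:
  # reinterpret the 16-bit payload back as the upper half of a float32
  return (b.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, -2.5, 3.14159], dtype=np.float32)
print(f32_from_bf16_bits(bf16_bits_from_f32(x)))  # [1.0, -2.5, 3.140625]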
@@ -141,6 +141,11 @@ class PythonProgram:
         def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
         def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
         ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
+      elif arg[4] == "AMD" and len(inp[0]) == 8: # RDNA4
+        def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
+        def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
+        def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
+        ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
       elif arg[4] == "AMD":
         # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
         def a_elem(x, k, row, goff):
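In the RDNA4 path each of the 32 lanes carries 8 output elements (32 x 8 = 256 = 16 x 16, versus 64 lanes x 4 elements in the MFMA path above). A quick standalone check that the new c_map places every element of the 16x16 output tile exactly once:

def c_map(lane, elem): return (lane % 16, (lane // 16) * 8 + elem)  # copied from the RDNA4 branch above

coords = [c_map(lane, elem) for lane in range(32) for elem in range(8)]
assert len(coords) == len(set(coords)) == 16 * 16  # 256 distinct (row, col) positions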
@@ -197,6 +202,7 @@ class PythonRenderer(Renderer):
     if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
     if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
     if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_mfma
+    if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores_rdna4
     if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
     if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
     if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
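A hedged usage sketch of driving the new emulation path from Python instead of the shell. Assumptions: the environment variables must be set before tinygrad is imported, and whether a particular matmul actually lowers to the WMMA path still depends on the usual tensor-core heuristics, so this only shows backend selection.

import os
os.environ["PYTHON"] = "1"              # make the pure-Python emulated device the default
os.environ["EMULATE_AMD_RDNA4"] = "1"   # emulate the RDNA4 WMMA tensor cores from this PR

from tinygrad import Tensor, dtypes

a = Tensor.rand(16, 16, dtype=dtypes.half)
b = Tensor.rand(16, 16, dtype=dtypes.half)
print((a @ b).numpy())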