mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
amd tc: 16x16x32 (#12874)
* amd tc: 16x16x32 * test * clean, test amd_cdna4
This commit is contained in:
@@ -117,6 +117,14 @@ amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4),
|
||||
(('l0', 'l1', 'l2', 'l3', 'r2', 'r3'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1'))))
|
||||
for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
|
||||
|
||||
amd_cdna_161632 = [TensorCore(dims=(16,16,32), threads=64, elements_per_thread=(8,8,4), dtype_in=di, dtype_out=do,
|
||||
opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
|
||||
swizzle=((('u0','u1','l4','l5','r3','r4'), ('r0','r1'), ('l0','l1','l2','l3','r2')),
|
||||
(('l0','l1','l2','l3','r3','r4'), ('r0','r1'), ('l4','l5','u0','u1','r2'))))
|
||||
for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
|
||||
|
||||
amd_cdna4 = amd_cdna_161632 + amd_cdna
|
||||
|
||||
# ***** Apple Metal *****
|
||||
|
||||
metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do,
|
||||
|
||||
@@ -423,7 +423,7 @@ class AMDRenderer(CStyleLanguage):
|
||||
|
||||
@staticmethod
|
||||
def get_tensor_cores(arch):
|
||||
return {"gfx942": tc.amd_cdna, "gfx950": tc.amd_cdna, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3)
|
||||
return {"gfx942": tc.amd_cdna, "gfx950": tc.amd_cdna4, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3)
|
||||
def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700
|
||||
self.arch = arch
|
||||
self.tensor_cores = self.get_tensor_cores(arch)
|
||||
|
||||
@@ -49,9 +49,11 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
|
||||
def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
|
||||
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16"}
|
||||
# https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
|
||||
N,M,K = wmma.arg[1]
|
||||
if cdna:
|
||||
if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
|
||||
f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
f".{N}x{M}x{K}{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
|
||||
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
|
||||
|
||||
@@ -150,10 +150,10 @@ class PythonProgram:
|
||||
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
|
||||
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
|
||||
elif device == "AMD" and threads == 64:
|
||||
def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
|
||||
def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
|
||||
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
|
||||
ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
|
||||
ul[i] = wmma_helper(64, dims[2], len(inp[0]), len(inp[1]), len(inp[2]), a_elem, b_elem, c_map)
|
||||
elif device == "AMD" and len(inp[0]) == 8: # RDNA4
|
||||
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
|
||||
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
|
||||
@@ -221,7 +221,7 @@ class PythonRenderer(Renderer):
|
||||
match cast(str, EMULATE.value):
|
||||
case "METAL": self.device, self.tensor_cores = "METAL", tc.metal
|
||||
case "AMD": self.device, self.tensor_cores = "AMD", tc.amd_rdna3
|
||||
case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna
|
||||
case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna4
|
||||
case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
|
||||
case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
|
||||
case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
|
||||
|
||||
Reference in New Issue
Block a user