From a78fcc55a4dca1ab87f0b7457557eaa0e5ef1a6a Mon Sep 17 00:00:00 2001
From: b1tg <33436708+b1tg@users.noreply.github.com>
Date: Fri, 2 Jan 2026 22:01:05 +0800
Subject: [PATCH] amd tc 1616128 (#13439)

* amd tc 1616128

* fix test

* remove hardcoded check in test
---
 test/opt/test_tensor_cores.py |  3 ++-
 tinygrad/codegen/opt/tc.py    |  8 +++++++-
 tinygrad/renderer/cstyle.py   | 13 +++++++++----
 3 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/test/opt/test_tensor_cores.py b/test/opt/test_tensor_cores.py
index 0b6bad6e9f..7cb5bb7d39 100644
--- a/test/opt/test_tensor_cores.py
+++ b/test/opt/test_tensor_cores.py
@@ -11,6 +11,7 @@ from tinygrad.helpers import AMX, AMD_LLVM, CPU_LLVM, Context
 from test.helpers import slow
 from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
+from tinygrad.codegen.opt.tc import amd_cdna_1616128
 
 # TODO: write a clean version of this
 from test.test_linearizer import helper_realized_ast, helper_linearizer_opt
@@ -120,7 +121,7 @@ class TestTensorCores(unittest.TestCase):
       # check excessive padding doesn't trigger padded TC in TC_OPT=2
       helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
       helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
-      if not AMX: # AMX tc.dims[2] == 1
+      if not AMX and tc not in amd_cdna_1616128: # AMX tc.dims[2] == 1
        helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//8, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
 
   @Context(ALLOW_TF32=1)
diff --git a/tinygrad/codegen/opt/tc.py b/tinygrad/codegen/opt/tc.py
index 6a05f0bd16..fadad45296 100644
--- a/tinygrad/codegen/opt/tc.py
+++ b/tinygrad/codegen/opt/tc.py
@@ -121,9 +121,15 @@ amd_cdna_161632 = [TensorCore(dims=(16,16,32), threads=64, elements_per_thread=(
            (('l0', 'l1', 'l2', 'l3', 'r3', 'r4'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1', 'r2'))))
   for di,do in [(dtypes.fp8e5m2,dtypes.float),(dtypes.fp8e4m3,dtypes.float),(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
 
+amd_cdna_1616128 = [TensorCore(dims=(16,16,128), threads=64, elements_per_thread=(32,32,4), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
+  swizzle=((('u0', 'u1', 'l4', 'l5', 'r5', 'r6'), ('r0', 'r1'), ('l0', 'l1', 'l2', 'l3', 'r2', 'r3', 'r4')),
+           (('l0', 'l1', 'l2', 'l3', 'r5', 'r6'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1', 'r2', 'r3', 'r4'))))
+  for di,do in [(dtypes.fp8e5m2,dtypes.float),(dtypes.fp8e4m3,dtypes.float)]]
+
 amd_cdna3 = amd_cdna_161632[:2] + amd_cdna_161616
-amd_cdna4 = amd_cdna_161632 + amd_cdna_161616
+amd_cdna4 = amd_cdna_1616128 + amd_cdna_161632 + amd_cdna_161616
 
 # ***** Apple Metal *****
 
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index c5716c2370..fd078eb535 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -373,7 +373,7 @@ class MetalRenderer(CStyleLanguage):
   simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n  return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-_nms = "xyzwabcdefghijkl"
+_nms = list("xyzwabcdefghijkl") + [f'v{i}' for i in range(16, 32)]
 
 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
@@ -446,6 +446,8 @@ class CUDARenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 
+def fp8_index(dtype: DType): return (dtypes.fp8e4m3, dtypes.fp8e5m2).index(dtype.scalar())
+
 class AMDHIPRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
@@ -463,11 +465,13 @@ class AMDHIPRenderer(CStyleLanguage):
     self.tensor_cores = self.get_tensor_cores(arch)
     if self.is_cdna(self.arch):
       self.string_rewrite = PatternMatcher([
+        (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]},"
+          f" {fp8_index(x.src[0].dtype)}, {fp8_index(x.src[0].dtype)}, 0, 0, 0, 0)" if x.arg[1][2] == 128 else None),
        (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)"),
         (UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",),
-          lambda ctx,x, y: f"f32_to_fp8({ctx[x.src[0]]}, {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
+          lambda ctx,x,y: f"f32_to_fp8({ctx[x.src[0]]}, {fp8_index(x.dtype)})"),
         (UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",),
-          lambda ctx,x, y: f"__builtin_amdgcn_cvt_f32_{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}((unsigned int){ctx[x.src[0]]}, 0)"),
+          lambda ctx,x,y: f"__builtin_amdgcn_cvt_f32_{('fp8', 'bf8')[fp8_index(y.dtype)]}((unsigned int){ctx[x.src[0]]}, 0)"),
       ]) + base_rewrite
 
   def __reduce__(self): return self.__class__, (self.arch,)
@@ -527,7 +531,8 @@ class AMDHIPRenderer(CStyleLanguage):
       if self.is_cdna(self.arch):
         if (N, M, K) == (16, 16, 16): type_map[dtypes.bfloat16] = 'bf16_1k'
         elif (N, M, K) == (16, 16, 32): type_map = {**type_map, dtypes.bfloat16: "_bf16", dtypes.half: "_f16"}
-        prefix.append(f"#define __{name} __builtin_amdgcn_mfma_f32_{N}x{M}x{K}{type_map[dtype_in]}")
+        elif (N, M, K) == (16, 16, 128): type_map = {**type_map, dtypes.fp8e4m3: "_f8f6f4", dtypes.fp8e5m2: "_f8f6f4"}
+        prefix.append(f"#define __{name} __builtin_amdgcn_mfma_{'scale_' if K == 128 else ''}f32_{N}x{M}x{K}{type_map[dtype_in]}")
       # #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
       elif self.tensor_cores == tc.amd_rdna4: prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12")
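
Note (not part of the patch): a minimal standalone sketch of what the new fp8_index helper computes, assuming tinygrad's dtypes module as imported by tinygrad/renderer/cstyle.py. The same index replaces the hard-coded '0'/'1' in the fp8 cast rewrites and is reused as the two operand-format arguments of the K == 128 WMMA rewrite above.

# sketch only: fp8_index maps the two fp8 dtypes to the integer passed to the amdgcn builtins
from tinygrad.dtype import DType, dtypes

def fp8_index(dtype: DType) -> int:
  # same expression as the helper added in tinygrad/renderer/cstyle.py
  return (dtypes.fp8e4m3, dtypes.fp8e5m2).index(dtype.scalar())

assert fp8_index(dtypes.fp8e4m3) == 0  # cast renders __builtin_amdgcn_cvt_f32_fp8(..., 0)
assert fp8_index(dtypes.fp8e5m2) == 1  # cast renders __builtin_amdgcn_cvt_f32_bf8(..., 0)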