From 48a7627b040d63d8d18fa5e3e536874862afef26 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:48:20 +0800 Subject: [PATCH] add RDNA4 support to copy WMMA (#15663) * add RDNA4 supportt to copy WMMA * simpler * simpler * comment * assert --- extra/gemm/amd_copy_matmul.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/extra/gemm/amd_copy_matmul.py b/extra/gemm/amd_copy_matmul.py index 5ab4f4fb77..2557b19835 100644 --- a/extra/gemm/amd_copy_matmul.py +++ b/extra/gemm/amd_copy_matmul.py @@ -1,4 +1,4 @@ -from tinygrad import UOp, getenv +from tinygrad import Device, UOp, getenv from tinygrad.uop.ops import AxisType, KernelInfo, Ops from tinygrad.dtype import AddrSpace, dtypes @@ -13,18 +13,23 @@ assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0 use_wmma = getenv("WMMA") if use_wmma: + is_rdna4 = Device[Device.DEFAULT].renderer.target.arch.startswith("gfx12") + WAVES_M, WAVES_N = 2, 2 LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16 - UNROLL_M, UNROLL_N = 1, 1 # wmma params WMMA_M, WMMA_N, WMMA_K = 16, 16, 16 WMMA_ACC = WMMA_M // LANES_PER_WAVE_M + UNROLL_M, UNROLL_N = (WMMA_ACC, 1) if is_rdna4 else (1, 1) else: WAVES_M, WAVES_N = 4, 1 LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8 UNROLL_M, UNROLL_N = 4, 4 +# total lanes must be the warp size +assert LANES_PER_WAVE_M*LANES_PER_WAVE_N == WARP_SIZE + # WARP_SIZE * total waves THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N @@ -71,7 +76,10 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp: acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0,2,1)[tile_m, tile_n] a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k] b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k] - + if is_rdna4: + # NOTE: since this is part of K, these 2 can be anywhere in the frags and long as a and b match + a_frag = a_frag.reshape(2, 8)[lane_m, :] + b_frag = b_frag.reshape(2, 8)[lane_m, :] wmma = UOp(Ops.SHAPED_WMMA, dtypes.float, (a_frag, b_frag, acc_frag.after(k)), arg=((16, 16, 16), 'AMD', 32)) acc_store = acc_frag.store(wmma).end(tile_m, tile_n) else: