add RDNA4 support to copy WMMA (#15663)

* add RDNA4 supportt to copy WMMA

* simpler

* simpler

* comment

* assert
This commit is contained in:
George Hotz
2026-04-09 22:48:20 +08:00
committed by GitHub
parent 6837881b06
commit 48a7627b04

View File

@@ -1,4 +1,4 @@
from tinygrad import UOp, getenv
from tinygrad import Device, UOp, getenv
from tinygrad.uop.ops import AxisType, KernelInfo, Ops
from tinygrad.dtype import AddrSpace, dtypes
@@ -13,18 +13,23 @@ assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0
use_wmma = getenv("WMMA")
if use_wmma:
is_rdna4 = Device[Device.DEFAULT].renderer.target.arch.startswith("gfx12")
WAVES_M, WAVES_N = 2, 2
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
UNROLL_M, UNROLL_N = 1, 1
# wmma params
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
UNROLL_M, UNROLL_N = (WMMA_ACC, 1) if is_rdna4 else (1, 1)
else:
WAVES_M, WAVES_N = 4, 1
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8
UNROLL_M, UNROLL_N = 4, 4
# total lanes must be the warp size
assert LANES_PER_WAVE_M*LANES_PER_WAVE_N == WARP_SIZE
# WARP_SIZE * total waves
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N
@@ -71,7 +76,10 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0,2,1)[tile_m, tile_n]
a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k]
b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k]
if is_rdna4:
# NOTE: since this is part of K, these 2 can be anywhere in the frags and long as a and b match
a_frag = a_frag.reshape(2, 8)[lane_m, :]
b_frag = b_frag.reshape(2, 8)[lane_m, :]
wmma = UOp(Ops.SHAPED_WMMA, dtypes.float, (a_frag, b_frag, acc_frag.after(k)), arg=((16, 16, 16), 'AMD', 32))
acc_store = acc_frag.store(wmma).end(tile_m, tile_n)
else: