mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add RDNA4 support to copy WMMA (#15663)
* add RDNA4 support to copy WMMA * simpler * simpler * comment * assert
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from tinygrad import UOp, getenv
|
||||
from tinygrad import Device, UOp, getenv
|
||||
from tinygrad.uop.ops import AxisType, KernelInfo, Ops
|
||||
from tinygrad.dtype import AddrSpace, dtypes
|
||||
|
||||
@@ -13,18 +13,23 @@ assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0
|
||||
|
||||
use_wmma = getenv("WMMA")
|
||||
if use_wmma:
|
||||
is_rdna4 = Device[Device.DEFAULT].renderer.target.arch.startswith("gfx12")
|
||||
|
||||
WAVES_M, WAVES_N = 2, 2
|
||||
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
|
||||
UNROLL_M, UNROLL_N = 1, 1
|
||||
|
||||
# wmma params
|
||||
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
|
||||
WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
|
||||
UNROLL_M, UNROLL_N = (WMMA_ACC, 1) if is_rdna4 else (1, 1)
|
||||
else:
|
||||
WAVES_M, WAVES_N = 4, 1
|
||||
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8
|
||||
UNROLL_M, UNROLL_N = 4, 4
|
||||
|
||||
# total lanes must be the warp size
|
||||
assert LANES_PER_WAVE_M*LANES_PER_WAVE_N == WARP_SIZE
|
||||
|
||||
# WARP_SIZE * total waves
|
||||
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N
|
||||
|
||||
@@ -71,7 +76,10 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0,2,1)[tile_m, tile_n]
|
||||
a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k]
|
||||
b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k]
|
||||
|
||||
if is_rdna4:
|
||||
# NOTE: since this is part of K, these 2 can be anywhere in the frags as long as a and b match
|
||||
a_frag = a_frag.reshape(2, 8)[lane_m, :]
|
||||
b_frag = b_frag.reshape(2, 8)[lane_m, :]
|
||||
wmma = UOp(Ops.SHAPED_WMMA, dtypes.float, (a_frag, b_frag, acc_frag.after(k)), arg=((16, 16, 16), 'AMD', 32))
|
||||
acc_store = acc_frag.store(wmma).end(tile_m, tile_n)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user