update amd_uop_matmul (#15236)

* update amd_uop_matmul

* use custom kernel

* simpler

* ignore
This commit is contained in:
George Hotz
2026-03-12 17:33:12 +08:00
committed by GitHub
parent 90b7f4341d
commit e560a46f59

View File

@@ -1,51 +1,32 @@
import numpy as np
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad import Tensor, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp, KernelInfo, sint, AxisType
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import getenv
from tinygrad.helpers import DEBUG, getenv
N = getenv("N", 4096)
M = K = N
run_count = getenv("CNT", 5)
M = getenv("M", N)
K = getenv("K", N)
NUM_RUNS = getenv("CNT", 5)
# ---------------------------
# launch/config constants
# ---------------------------
WARP_SIZE = 32
# Threadblock tile sizes (block-level tile of C that a block computes)
BLOCK_N = 128 # columns of C (N-dim) per block
BLOCK_M = 128 # rows of C (M-dim) per block
BLOCK_K = 8 # K-slice per block iteration
# Register tile sizes (per-thread accumulator tile of C)
TN = 4 # columns per thread
TM = 4 # rows per thread
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 8
TM, TN = 4, 4
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8
assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0
is_kernel5 = getenv("K5", 0)
THREADS_PER_BLOCK = 128 if is_kernel5 else 256
assert THREADS_PER_BLOCK % BLOCK_N == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_N"
assert THREADS_PER_BLOCK % BLOCK_K == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_K"
assert (BLOCK_N * BLOCK_K) % THREADS_PER_BLOCK == 0
assert (BLOCK_M * BLOCK_K) % THREADS_PER_BLOCK == 0
WAVES_PER_BLOCK_N = 1 if is_kernel5 else 2
WAVES_PER_BLOCK_M = THREADS_PER_BLOCK // WARP_SIZE // WAVES_PER_BLOCK_N
REG_TILES_PER_WAVE_N = BLOCK_N // (WAVES_PER_BLOCK_N * LANES_PER_WAVE_N * TN)
REG_TILES_PER_WAVE_M = BLOCK_M // (WAVES_PER_BLOCK_M * LANES_PER_WAVE_M * TM)
WARPS_PER_BLOCK = THREADS_PER_BLOCK // WARP_SIZE
WAVE_TILE_N = 128 if is_kernel5 else 64
WAVE_TILE_M = BLOCK_N * BLOCK_M // WARPS_PER_BLOCK // WAVE_TILE_N
assert BLOCK_N % WAVE_TILE_N == 0, "BN must be a multiple of WN"
assert BLOCK_M % WAVE_TILE_M == 0, "BM must be a multiple of WM"
WAVES_IN_BLOCK_X = BLOCK_N // WAVE_TILE_N
WAVES_IN_BLOCK_Y = BLOCK_M // WAVE_TILE_M
assert WAVES_IN_BLOCK_X * WAVES_IN_BLOCK_Y == WARPS_PER_BLOCK, "wave grid must match warps/block"
LANES_PER_WAVE_X = 8
LANES_PER_WAVE_Y = 4
ITERS_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_X * TN)
ITERS_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_Y * TM)
assert WAVE_TILE_N % (LANES_PER_WAVE_X * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_X*TN"
assert WAVE_TILE_M % (LANES_PER_WAVE_Y * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_Y*TM"
assert WAVES_PER_BLOCK_M*REG_TILES_PER_WAVE_M*LANES_PER_WAVE_M*TM == BLOCK_M, "M reshape is wrong"
assert WAVES_PER_BLOCK_N*REG_TILES_PER_WAVE_N*LANES_PER_WAVE_N*TN == BLOCK_N, "N reshape is wrong"
def rngs_for_shape(shape:tuple[sint, ...], rng:int, axis_type=AxisType.LOOP): return [UOp.range(s, rng+i, axis_type) for i,s in enumerate(shape)]
def copy(dest:UOp, src:UOp, rng:int, set=False, upcast=False):
@@ -54,45 +35,41 @@ def copy(dest:UOp, src:UOp, rng:int, set=False, upcast=False):
copy = dest[*rngs].store(src[*rngs]).end(*rngs)
return dest.after(copy) if set else copy
def hand_spec_kernel3():
def hand_spec_kernel3(c:UOp, a:UOp, b:UOp) -> UOp:
# ---------------------------
# block indices & placeholders
# block indices
# ---------------------------
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
a = UOp.placeholder((N, N), dtypes.float, slot=1)
b = UOp.placeholder((N, N), dtypes.float, slot=2)
c = UOp.placeholder((N, N), dtypes.float, slot=0)
block_id_n = UOp.special(N // BLOCK_N, "gidx0")
block_id_m = UOp.special(M // BLOCK_M, "gidx1")
# index the output with the globals
c = c.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_N, BLOCK_N)[blockIdx_y, :, blockIdx_x, :]
c = c.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_N, BLOCK_N)[block_id_m, :, block_id_n, :]
# open the main reduction range
k_tile_range = UOp.range(N // BLOCK_K, 0, AxisType.REDUCE)
a = a.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_K, BLOCK_K)[blockIdx_y, :, k_tile_range, :]
b = b.reshape(N // BLOCK_K, BLOCK_K, N // BLOCK_N, BLOCK_N)[k_tile_range, :, blockIdx_x, :]
k_tile_range = UOp.range(K // BLOCK_K, 0, AxisType.REDUCE)
a = a.reshape(M // BLOCK_M, BLOCK_M, K // BLOCK_K, BLOCK_K)[block_id_m, :, k_tile_range, :]
b = b.reshape(K // BLOCK_K, BLOCK_K, N // BLOCK_N, BLOCK_N)[k_tile_range, :, block_id_n, :]
# globals are no longer used, they are already in the indexes
del blockIdx_y, blockIdx_x
del block_id_m, block_id_n
# ---------------------------
# GLOBAL -> LOCAL (As, Bs)
# GLOBAL -> LOCAL (A_local, B_local)
# ---------------------------
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
# A: read BM x BK tiles (permute on store into locals)
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
As = UOp.placeholder((BLOCK_K, BM_As_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
As_store = copy(As.permute((1,0)).reshape(-1, THREADS_PER_BLOCK)[:, tid], a.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=100)
BM_A_local_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
A_local = UOp.placeholder((BLOCK_K, BM_A_local_stride), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
A_local_store = copy(A_local.permute((1,0)).reshape(-1, THREADS_PER_BLOCK)[:, tid], a.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=100)
# B: read BK x BN tiles
Bs = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
Bs_store = copy(Bs.reshape(-1, THREADS_PER_BLOCK)[:, tid], b.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=200)
B_local = UOp.placeholder((BLOCK_K, BLOCK_N), dtypes.float, slot=1, addrspace=AddrSpace.LOCAL)
B_local_store = copy(B_local.reshape(-1, THREADS_PER_BLOCK)[:, tid], b.reshape(-1, THREADS_PER_BLOCK)[:, tid], rng=200)
# TODO: can we automate barrier?
barrier = UOp.barrier(As_store, Bs_store)
As, Bs = As.after(barrier), Bs.after(barrier)
barrier = UOp.barrier(A_local_store, B_local_store)
A_local, B_local = A_local.after(barrier), B_local.after(barrier)
# open inner k range
k = UOp.range(BLOCK_K, 3, AxisType.REDUCE)
@@ -100,31 +77,33 @@ def hand_spec_kernel3():
# ---------------------------
# LOCAL -> REG (per-wave tiles)
# ---------------------------
waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y
waveIdx = (tid // WARP_SIZE) % WAVES_PER_BLOCK_N
waveIdy = (tid // WARP_SIZE) // WAVES_PER_BLOCK_N
assert waveIdy.vmax+1 == WAVES_PER_BLOCK_M
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_X
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_X
assert laneIdy.vmax+1 == LANES_PER_WAVE_Y
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_N
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_N
assert laneIdy.vmax+1 == LANES_PER_WAVE_M
A_col = UOp.placeholder((ITERS_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
A_col = copy(A_col, As[k, :].reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)[waveIdy, :, laneIdy, :], 300, set=True, upcast=True)
A_col = UOp.placeholder((REG_TILES_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
A_local_slice = A_local[k, :].reshape(WAVES_PER_BLOCK_M, REG_TILES_PER_WAVE_M, LANES_PER_WAVE_M, TM)[waveIdy, :, laneIdy, :]
A_col = copy(A_col, A_local_slice, 300, set=True, upcast=True)
B_row = UOp.placeholder((ITERS_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
B_row = copy(B_row, Bs[k, :].reshape(WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)[waveIdx, :, laneIdx, :], 400, set=True, upcast=True)
B_row = UOp.placeholder((REG_TILES_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
B_local_slice = B_local[k, :].reshape(WAVES_PER_BLOCK_N, REG_TILES_PER_WAVE_N, LANES_PER_WAVE_N, TN)[waveIdx, :, laneIdx, :]
B_row = copy(B_row, B_local_slice, 400, set=True, upcast=True)
# ---------------------------
# FMA: c_regs += A_col * B_row
# ---------------------------
c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder((REG_TILES_PER_WAVE_M, TM, REG_TILES_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
i = UOp.range(c_regs.size, 16)
c_regs = c_regs.after(c_regs.flatten()[i].store(0.0).end(i))
# TODO: why don't these work as upcast?
# why if the ranges merge is it slow?!? (if you change the order on end, they will merge. big slowdown on METAL)
iterWaveM, yt, iterWaveN, xt = rngs = rngs_for_shape(c_regs.shape, 500)
sink = c_regs[*rngs].store(c_regs.after(k)[*rngs] + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt)
iter_m, t_m, iter_n, t_n = rngs = rngs_for_shape(c_regs.shape, 500)
sink = c_regs[*rngs].store(c_regs.after(k)[*rngs] + A_col[iter_m, t_m] * B_row[iter_n, t_n]).end(iter_m, iter_n, t_m, t_n)
# Close k, sync, and close K tiles
sink = sink.end(k).barrier().end(k_tile_range)
@@ -132,38 +111,34 @@ def hand_spec_kernel3():
# ---------------------------
# REG -> GLOBAL (epilogue)
# ---------------------------
c = c.reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
c = c.reshape(WAVES_PER_BLOCK_M, REG_TILES_PER_WAVE_M, LANES_PER_WAVE_M, TM,
WAVES_PER_BLOCK_N, REG_TILES_PER_WAVE_N, LANES_PER_WAVE_N, TN)
c = c[waveIdy, :, laneIdy, :,
waveIdx, :, laneIdx, :]
sink = copy(c, c_regs.after(sink), rng=600)
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
def test_matmul(sink:UOp, dtype=dtypes.float32, N=N):
rng = np.random.default_rng()
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
hc = Tensor.empty(N, N, dtype=dtype)
Tensor.realize(a, b, hc)
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))
if __name__ == "__main__":
a = Tensor.randn(M, K, dtype=dtypes.float)
b = Tensor.randn(K, N, dtype=dtypes.float)
c = Tensor.empty(M, N, dtype=dtypes.float)
with Context(DEBUG=0): Tensor.realize(a, b)
ets = []
with Context(DEBUG=2):
for _ in range(run_count):
ets.append(ei.run(wait=True))
print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")
with Context(DEBUG=max(2, DEBUG.value)):
for _ in range(NUM_RUNS):
GlobalCounters.reset()
tst = Tensor.custom_kernel(c, a, b, fxn=hand_spec_kernel3)[0].realize()
ets.append(GlobalCounters.time_sum_s)
print(f"REAL TFLOPS {M * N * K * 2 / min(ets) * 1e-12:.2f}")
if getenv("VERIFY", 1):
GlobalCounters.reset()
with Context(DEBUG=2):
tc = (a @ b).realize()
with Context(DEBUG=0):
err = (hc - tc).square().mean().item()
err = (tc - tst).square().mean().item()
print(f"mean squared error {err}")
if err > 1e-06:
raise RuntimeError("matmul is wrong!")
if __name__ == "__main__":
test_matmul(hand_spec_kernel3(), N=N)