cleanup amd uop matmul (#13025)

* cleanup amd uop matmul

* remove mod

* move that out

* better variable names

* var names

* more

* render fallback

* colors
This commit is contained in:
George Hotz
2025-10-31 10:04:45 +08:00
committed by GitHub
parent f6430a0559
commit 512513c403
3 changed files with 84 additions and 90 deletions

View File

@@ -7,106 +7,100 @@ from tinygrad.helpers import getenv
N = 4096
run_count = 5
# block for locals
BN = 128
BM = 128
BK = 8
# ---------------------------
# launch/config constants
# ---------------------------
# t for registers
TN = 4
TM = 4
WARP_SIZE = 32
# Threadblock tile sizes (block-level tile of C that a block computes)
BLOCK_N = 128 # columns of C (N-dim) per block
BLOCK_M = 128 # rows of C (M-dim) per block
BLOCK_K = 8 # K-slice per block iteration
def hand_spec_kernel3(kernel5=getenv("K5", 0)):
# ---------------------------
# launch/config constants
# ---------------------------
# Register tile sizes (per-thread accumulator tile of C)
TN = 4 # columns per thread
TM = 4 # rows per thread
BLOCK_SIZE = 128 if kernel5 else 256
is_kernel5 = getenv("K5", 0)
THREADS_PER_BLOCK = 128 if is_kernel5 else 256
assert THREADS_PER_BLOCK % BLOCK_N == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_N"
assert THREADS_PER_BLOCK % BLOCK_K == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_K"
assert (BLOCK_N * BLOCK_K) % THREADS_PER_BLOCK == 0
assert (BLOCK_M * BLOCK_K) % THREADS_PER_BLOCK == 0
nbWaves = BLOCK_SIZE // 32
WN = 128 if kernel5 else 64
WM = BN * BM // nbWaves // WN
WARPS_PER_BLOCK = THREADS_PER_BLOCK // WARP_SIZE
WAVE_TILE_N = 128 if is_kernel5 else 64
WAVE_TILE_M = BLOCK_N * BLOCK_M // WARPS_PER_BLOCK // WAVE_TILE_N
assert BLOCK_N % WAVE_TILE_N == 0, "BN must be a multiple of WN"
assert BLOCK_M % WAVE_TILE_M == 0, "BM must be a multiple of WM"
WAVES_IN_BLOCK_X = BLOCK_N // WAVE_TILE_N
WAVES_IN_BLOCK_Y = BLOCK_M // WAVE_TILE_M
assert WAVES_IN_BLOCK_X * WAVES_IN_BLOCK_Y == WARPS_PER_BLOCK, "wave grid must match warps/block"
# Sanity checks (fail fast if shapes/tiles misalign)
assert BN % WN == 0, "BN must be a multiple of WN"
assert BM % WM == 0, "BM must be a multiple of WM"
nbWaveX = BN // WN
nbWaveY = BM // WM
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
assert (BN * BK) % BLOCK_SIZE == 0
assert (BM * BK) % BLOCK_SIZE == 0
LANES_PER_WAVE_X = 8
LANES_PER_WAVE_Y = 4
ITERS_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_X * TN)
ITERS_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_Y * TM)
N_PER_ITER = WAVE_TILE_N // ITERS_PER_WAVE_N
M_PER_ITER = WAVE_TILE_M // ITERS_PER_WAVE_M
assert WAVE_TILE_N % (LANES_PER_WAVE_X * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_X*TN"
assert WAVE_TILE_M % (LANES_PER_WAVE_Y * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_Y*TM"
def hand_spec_kernel3():
# ---------------------------
# per-thread read mapping
# ---------------------------
# A: read BK x BN tiles; B: read BN x BK tiles
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0")
waveIndex = threadIdx_x // 32
waveIdx = waveIndex % nbWaveX
waveIdy = waveIndex // nbWaveX
indexInWave = threadIdx_x % 32
waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y
nbThreadXPerWave = 8
nbThreadYPerWave = 4
idxInWave = indexInWave % nbThreadXPerWave
idyInWave = indexInWave // nbThreadXPerWave
nbIterWaveN = WN // (nbThreadXPerWave * TN)
nbIterWaveM = WM // (nbThreadYPerWave * TM)
SUBWN = WN // nbIterWaveN
SUBWM = WM // nbIterWaveM
idxInWave = (tid % WARP_SIZE) % LANES_PER_WAVE_X
idyInWave = (tid % WARP_SIZE) // LANES_PER_WAVE_X
assert idyInWave.vmax+1 == LANES_PER_WAVE_Y
# ---------------------------
# block indices & placeholders
# ---------------------------
blockIdx_x = UOp.special(N // BN, "gidx0")
blockIdx_y = UOp.special(N // BM, "gidx1")
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
BM_As_stride = (BM + 4) if kernel5 else BM
As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL)
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
A_col = UOp.placeholder(dtypes.float, (nbIterWaveM, TM), slot=0, addrspace=AddrSpace.REG)
B_row = UOp.placeholder(dtypes.float, (nbIterWaveN, TN), slot=1, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder(dtypes.float, (nbIterWaveM, TM, nbIterWaveN, TN), slot=2, addrspace=AddrSpace.REG)
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG)
c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG)
i = UOp.range(c_regs.dtype.size, 16)
i = UOp.range(c_regs.size, 16)
c_regs = c_regs[i].set(0.0, end=i)
kId_range = UOp.range(N // BK, 0)
kId = kId_range * BK
k_tile_range = UOp.range(N // BLOCK_K, 0)
# ---------------------------
# GLOBAL -> LOCAL (As, Bs)
# ---------------------------
nbReadsB = BN * BK // BLOCK_SIZE
i = UOp.range(nbReadsB, 1)
rBIdx = threadIdx_x % BN
rBIdy = threadIdx_x // BN
strideReadB = BLOCK_SIZE // BN
index_x = BN * blockIdx_x + rBIdx
index_y = rBIdy + i * strideReadB + kId
Bs_store = Bs[index_y % BK, index_x % BN].store(b[index_y, index_x]).end(i)
b = b.reshape((N // BLOCK_K, BLOCK_K,
N // BLOCK_N, BLOCK_N))
i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1)
index_x = tid % BLOCK_N
index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i
Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i)
nbReadsA = BM * BK // BLOCK_SIZE
i = UOp.range(nbReadsA, 2)
rAIdx = threadIdx_x % BK
rAIdy = threadIdx_x // BK
strideReadA = BLOCK_SIZE // BK
index_x = rAIdx + kId
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
As_store = As[index_x % BK, index_y % BM].store(a[index_y, index_x]).end(i)
a = a.reshape((N // BLOCK_M, BLOCK_M,
N // BLOCK_K, BLOCK_K))
i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2)
index_x = tid % BLOCK_K
index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i
As_store = As[index_x, index_y].store(a[blockIdx_y, index_y, k_tile_range, index_x]).end(i)
# TODO: can we automate barrier?
barrier = UOp.barrier(As_store, Bs_store)
@@ -114,44 +108,45 @@ def hand_spec_kernel3(kernel5=getenv("K5", 0)):
As = As.after(barrier)
# open inner k range
k = UOp.range(BK, 3)
k = UOp.range(BLOCK_K, 3)
# ---------------------------
# LOCAL -> REG (per-wave tiles)
# ---------------------------
iterWave = UOp.range(nbIterWaveN, 4)
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
i = UOp.range(TN, 5)
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
B_row = B_row[iterWave, i].set(Bs[k, index], end=(iterWave, i))
index = waveIdx * WAVE_TILE_N + iterWaveN * N_PER_ITER + idxInWave * TN + i
B_row = B_row[iterWaveN, i].set(Bs[k, index], end=(iterWaveN, i))
iterWave = UOp.range(nbIterWaveM, 6)
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
i = UOp.range(TM, 7)
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
A_col = A_col[iterWave, i].set(As[k, index], end=(iterWave, i))
index = waveIdy * WAVE_TILE_M + iterWaveM * M_PER_ITER + idyInWave * TM + i
A_col = A_col[iterWaveM, i].set(As[k, index], end=(iterWaveM, i))
# ---------------------------
# FMA: c_regs += A_col * B_row
# ---------------------------
iterWaveM = UOp.range(nbIterWaveM, 8)
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 8)
yt = UOp.range(TM, 9)
iterWaveN = UOp.range(nbIterWaveN, 10)
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 10)
xt = UOp.range(TN, 12)
c_idx = c_regs.after(k, kId_range)[iterWaveM, yt, iterWaveN, xt]
c_idx = c_regs.after(k, k_tile_range)[iterWaveM, yt, iterWaveN, xt]
sink = c_idx.store(c_idx + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt)
# Close k, sync, and close K tiles
sink = sink.end(k).barrier().end(kId_range)
sink = sink.end(k).barrier().end(k_tile_range)
# ---------------------------
# REG -> GLOBAL (epilogue)
# ---------------------------
iterWaveM = UOp.range(nbIterWaveM, 1000)
c = c.reshape((N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000)
yt = UOp.range(TM, 1001)
iterWaveN = UOp.range(nbIterWaveN, 1002)
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002)
xt = UOp.range(TN, 1003)
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
sink = c[yOut + yt, xOut + xt].store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
c_glbl_idx = c[blockIdx_y, waveIdy, iterWaveM, idyInWave, yt, blockIdx_x, waveIdx, iterWaveN, idxInWave, xt]
sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
return sink.sink(arg=KernelInfo(opts_to_apply=()))

View File

@@ -750,9 +750,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ctx: dict[UOp, str] = {}
pm = renderer if pm is None else pm
for u in (s:=self.simplify() if simplify else self).toposort():
# if there is any node in the toposort we can't render, we just render the whole thing using UOp pretty printer
if (u_str:=pm.rewrite(u, ctx=ctx)) is None: return str(s)
ctx[u] = cast(str, u_str)
ctx[u] = cast(str, pm.rewrite(u, ctx=ctx))
return ctx[s]
def pyrender(self): return pyrender(self)
@@ -1277,6 +1275,7 @@ renderer = PatternMatcher([
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
(UPat(GroupOp.All, name="x"), lambda x: str(x)),
])
renderer_infer = PatternMatcher([

View File

@@ -14,9 +14,9 @@ from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
**{x:"#f2cb91" for x in GroupOp.Defines}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",