Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
cleanup amd uop matmul (#13025)
* cleanup amd uop matmul
* remove mod
* move that out
* better variable names
* var names
* more
* render fallback
* colors
This commit is contained in:
@@ -7,106 +7,100 @@ from tinygrad.helpers import getenv
|
||||
N = 4096
|
||||
run_count = 5
|
||||
|
||||
# block for locals
|
||||
BN = 128
|
||||
BM = 128
|
||||
BK = 8
|
||||
# ---------------------------
|
||||
# launch/config constants
|
||||
# ---------------------------
|
||||
|
||||
# t for registers
|
||||
TN = 4
|
||||
TM = 4
|
||||
WARP_SIZE = 32
|
||||
|
||||
# Threadblock tile sizes (block-level tile of C that a block computes)
|
||||
BLOCK_N = 128 # columns of C (N-dim) per block
|
||||
BLOCK_M = 128 # rows of C (M-dim) per block
|
||||
BLOCK_K = 8 # K-slice per block iteration
|
||||
|
||||
def hand_spec_kernel3(kernel5=getenv("K5", 0)):
|
||||
# ---------------------------
|
||||
# launch/config constants
|
||||
# ---------------------------
|
||||
# Register tile sizes (per-thread accumulator tile of C)
|
||||
TN = 4 # columns per thread
|
||||
TM = 4 # rows per thread
|
||||
|
||||
BLOCK_SIZE = 128 if kernel5 else 256
|
||||
is_kernel5 = getenv("K5", 0)
|
||||
THREADS_PER_BLOCK = 128 if is_kernel5 else 256
|
||||
assert THREADS_PER_BLOCK % BLOCK_N == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_N"
|
||||
assert THREADS_PER_BLOCK % BLOCK_K == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_K"
|
||||
assert (BLOCK_N * BLOCK_K) % THREADS_PER_BLOCK == 0
|
||||
assert (BLOCK_M * BLOCK_K) % THREADS_PER_BLOCK == 0
|
||||
|
||||
nbWaves = BLOCK_SIZE // 32
|
||||
WN = 128 if kernel5 else 64
|
||||
WM = BN * BM // nbWaves // WN
|
||||
WARPS_PER_BLOCK = THREADS_PER_BLOCK // WARP_SIZE
|
||||
WAVE_TILE_N = 128 if is_kernel5 else 64
|
||||
WAVE_TILE_M = BLOCK_N * BLOCK_M // WARPS_PER_BLOCK // WAVE_TILE_N
|
||||
assert BLOCK_N % WAVE_TILE_N == 0, "BN must be a multiple of WN"
|
||||
assert BLOCK_M % WAVE_TILE_M == 0, "BM must be a multiple of WM"
|
||||
WAVES_IN_BLOCK_X = BLOCK_N // WAVE_TILE_N
|
||||
WAVES_IN_BLOCK_Y = BLOCK_M // WAVE_TILE_M
|
||||
assert WAVES_IN_BLOCK_X * WAVES_IN_BLOCK_Y == WARPS_PER_BLOCK, "wave grid must match warps/block"
|
||||
|
||||
# Sanity checks (fail fast if shapes/tiles misalign)
|
||||
assert BN % WN == 0, "BN must be a multiple of WN"
|
||||
assert BM % WM == 0, "BM must be a multiple of WM"
|
||||
nbWaveX = BN // WN
|
||||
nbWaveY = BM // WM
|
||||
|
||||
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
|
||||
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
|
||||
|
||||
assert (BN * BK) % BLOCK_SIZE == 0
|
||||
assert (BM * BK) % BLOCK_SIZE == 0
|
||||
LANES_PER_WAVE_X = 8
|
||||
LANES_PER_WAVE_Y = 4
|
||||
ITERS_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_X * TN)
|
||||
ITERS_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_Y * TM)
|
||||
N_PER_ITER = WAVE_TILE_N // ITERS_PER_WAVE_N
|
||||
M_PER_ITER = WAVE_TILE_M // ITERS_PER_WAVE_M
|
||||
assert WAVE_TILE_N % (LANES_PER_WAVE_X * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_X*TN"
|
||||
assert WAVE_TILE_M % (LANES_PER_WAVE_Y * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_Y*TM"
|
||||
|
||||
def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# per-thread read mapping
|
||||
# ---------------------------
|
||||
# A: read BK x BN tiles; B: read BN x BK tiles
|
||||
tid = UOp.special(THREADS_PER_BLOCK, "lidx0")
|
||||
|
||||
threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0")
|
||||
waveIndex = threadIdx_x // 32
|
||||
waveIdx = waveIndex % nbWaveX
|
||||
waveIdy = waveIndex // nbWaveX
|
||||
indexInWave = threadIdx_x % 32
|
||||
waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X
|
||||
waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X
|
||||
assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y
|
||||
|
||||
nbThreadXPerWave = 8
|
||||
nbThreadYPerWave = 4
|
||||
|
||||
idxInWave = indexInWave % nbThreadXPerWave
|
||||
idyInWave = indexInWave // nbThreadXPerWave
|
||||
|
||||
nbIterWaveN = WN // (nbThreadXPerWave * TN)
|
||||
nbIterWaveM = WM // (nbThreadYPerWave * TM)
|
||||
|
||||
SUBWN = WN // nbIterWaveN
|
||||
SUBWM = WM // nbIterWaveM
|
||||
idxInWave = (tid % WARP_SIZE) % LANES_PER_WAVE_X
|
||||
idyInWave = (tid % WARP_SIZE) // LANES_PER_WAVE_X
|
||||
assert idyInWave.vmax+1 == LANES_PER_WAVE_Y
|
||||
|
||||
# ---------------------------
|
||||
# block indices & placeholders
|
||||
# ---------------------------
|
||||
blockIdx_x = UOp.special(N // BN, "gidx0")
|
||||
blockIdx_y = UOp.special(N // BM, "gidx1")
|
||||
blockIdx_x = UOp.special(N // BLOCK_N, "gidx0")
|
||||
blockIdx_y = UOp.special(N // BLOCK_M, "gidx1")
|
||||
|
||||
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
|
||||
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
|
||||
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
|
||||
|
||||
BM_As_stride = (BM + 4) if kernel5 else BM
|
||||
As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
|
||||
Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
|
||||
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
|
||||
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
A_col = UOp.placeholder(dtypes.float, (nbIterWaveM, TM), slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder(dtypes.float, (nbIterWaveN, TN), slot=1, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder(dtypes.float, (nbIterWaveM, TM, nbIterWaveN, TN), slot=2, addrspace=AddrSpace.REG)
|
||||
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
i = UOp.range(c_regs.dtype.size, 16)
|
||||
i = UOp.range(c_regs.size, 16)
|
||||
c_regs = c_regs[i].set(0.0, end=i)
|
||||
|
||||
kId_range = UOp.range(N // BK, 0)
|
||||
kId = kId_range * BK
|
||||
k_tile_range = UOp.range(N // BLOCK_K, 0)
|
||||
|
||||
# ---------------------------
|
||||
# GLOBAL -> LOCAL (As, Bs)
|
||||
# ---------------------------
|
||||
nbReadsB = BN * BK // BLOCK_SIZE
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
rBIdx = threadIdx_x % BN
|
||||
rBIdy = threadIdx_x // BN
|
||||
strideReadB = BLOCK_SIZE // BN
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[index_y % BK, index_x % BN].store(b[index_y, index_x]).end(i)
|
||||
b = b.reshape((N // BLOCK_K, BLOCK_K,
|
||||
N // BLOCK_N, BLOCK_N))
|
||||
i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1)
|
||||
index_x = tid % BLOCK_N
|
||||
index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i
|
||||
Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i)
|
||||
|
||||
nbReadsA = BM * BK // BLOCK_SIZE
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
rAIdx = threadIdx_x % BK
|
||||
rAIdy = threadIdx_x // BK
|
||||
strideReadA = BLOCK_SIZE // BK
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[index_x % BK, index_y % BM].store(a[index_y, index_x]).end(i)
|
||||
a = a.reshape((N // BLOCK_M, BLOCK_M,
|
||||
N // BLOCK_K, BLOCK_K))
|
||||
i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2)
|
||||
index_x = tid % BLOCK_K
|
||||
index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i
|
||||
As_store = As[index_x, index_y].store(a[blockIdx_y, index_y, k_tile_range, index_x]).end(i)
|
||||
|
||||
# TODO: can we automate barrier?
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
@@ -114,44 +108,45 @@ def hand_spec_kernel3(kernel5=getenv("K5", 0)):
|
||||
As = As.after(barrier)
|
||||
|
||||
# open inner k range
|
||||
k = UOp.range(BK, 3)
|
||||
k = UOp.range(BLOCK_K, 3)
|
||||
|
||||
# ---------------------------
|
||||
# LOCAL -> REG (per-wave tiles)
|
||||
# ---------------------------
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row = B_row[iterWave, i].set(Bs[k, index], end=(iterWave, i))
|
||||
index = waveIdx * WAVE_TILE_N + iterWaveN * N_PER_ITER + idxInWave * TN + i
|
||||
B_row = B_row[iterWaveN, i].set(Bs[k, index], end=(iterWaveN, i))
|
||||
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col = A_col[iterWave, i].set(As[k, index], end=(iterWave, i))
|
||||
index = waveIdy * WAVE_TILE_M + iterWaveM * M_PER_ITER + idyInWave * TM + i
|
||||
A_col = A_col[iterWaveM, i].set(As[k, index], end=(iterWaveM, i))
|
||||
|
||||
# ---------------------------
|
||||
# FMA: c_regs += A_col * B_row
|
||||
# ---------------------------
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
c_idx = c_regs.after(k, kId_range)[iterWaveM, yt, iterWaveN, xt]
|
||||
c_idx = c_regs.after(k, k_tile_range)[iterWaveM, yt, iterWaveN, xt]
|
||||
sink = c_idx.store(c_idx + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
# Close k, sync, and close K tiles
|
||||
sink = sink.end(k).barrier().end(kId_range)
|
||||
sink = sink.end(k).barrier().end(k_tile_range)
|
||||
|
||||
# ---------------------------
|
||||
# REG -> GLOBAL (epilogue)
|
||||
# ---------------------------
|
||||
iterWaveM = UOp.range(nbIterWaveM, 1000)
|
||||
c = c.reshape((N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000)
|
||||
yt = UOp.range(TM, 1001)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 1002)
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002)
|
||||
xt = UOp.range(TN, 1003)
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
sink = c[yOut + yt, xOut + xt].store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
|
||||
c_glbl_idx = c[blockIdx_y, waveIdy, iterWaveM, idyInWave, yt, blockIdx_x, waveIdx, iterWaveN, idxInWave, xt]
|
||||
sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
|
||||
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=()))
|
||||
|
||||
@@ -750,9 +750,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
ctx: dict[UOp, str] = {}
|
||||
pm = renderer if pm is None else pm
|
||||
for u in (s:=self.simplify() if simplify else self).toposort():
|
||||
# if there is any node in the toposort we can't render, we just render the whole thing using UOp pretty printer
|
||||
if (u_str:=pm.rewrite(u, ctx=ctx)) is None: return str(s)
|
||||
ctx[u] = cast(str, u_str)
|
||||
ctx[u] = cast(str, pm.rewrite(u, ctx=ctx))
|
||||
return ctx[s]
|
||||
|
||||
def pyrender(self): return pyrender(self)
|
||||
@@ -1277,6 +1275,7 @@ renderer = PatternMatcher([
|
||||
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
|
||||
(UPat(Ops.VECTORIZE, name="x"),
|
||||
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
|
||||
(UPat(GroupOp.All, name="x"), lambda x: str(x)),
|
||||
])
|
||||
|
||||
renderer_infer = PatternMatcher([
|
||||
|
||||
@@ -14,9 +14,9 @@ from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
**{x:"#f2cb91" for x in GroupOp.Defines}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
||||
Reference in New Issue
Block a user