From 512513c403605fa58be86296e6494b291b961565 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 31 Oct 2025 10:04:45 +0800 Subject: [PATCH] cleanup amd uop matmul (#13025) * cleanup amd uop matmul * remove mod * move that out * better variable names * var names * more * render fallback * colors --- extra/gemm/amd_uop_matmul.py | 165 +++++++++++++++++------------------ tinygrad/uop/ops.py | 5 +- tinygrad/viz/serve.py | 4 +- 3 files changed, 84 insertions(+), 90 deletions(-) diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index febc6b2098..1528ef17f1 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -7,106 +7,100 @@ from tinygrad.helpers import getenv N = 4096 run_count = 5 -# block for locals -BN = 128 -BM = 128 -BK = 8 +# --------------------------- +# launch/config constants +# --------------------------- -# t for registers -TN = 4 -TM = 4 +WARP_SIZE = 32 +# Threadblock tile sizes (block-level tile of C that a block computes) +BLOCK_N = 128 # columns of C (N-dim) per block +BLOCK_M = 128 # rows of C (M-dim) per block +BLOCK_K = 8 # K-slice per block iteration -def hand_spec_kernel3(kernel5=getenv("K5", 0)): - # --------------------------- - # launch/config constants - # --------------------------- +# Register tile sizes (per-thread accumulator tile of C) +TN = 4 # columns per thread +TM = 4 # rows per thread - BLOCK_SIZE = 128 if kernel5 else 256 +is_kernel5 = getenv("K5", 0) +THREADS_PER_BLOCK = 128 if is_kernel5 else 256 +assert THREADS_PER_BLOCK % BLOCK_N == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_N" +assert THREADS_PER_BLOCK % BLOCK_K == 0, "THREADS_PER_BLOCK must be divisible by BLOCK_K" +assert (BLOCK_N * BLOCK_K) % THREADS_PER_BLOCK == 0 +assert (BLOCK_M * BLOCK_K) % THREADS_PER_BLOCK == 0 - nbWaves = BLOCK_SIZE // 32 - WN = 128 if kernel5 else 64 - WM = BN * BM // nbWaves // WN +WARPS_PER_BLOCK = THREADS_PER_BLOCK // WARP_SIZE +WAVE_TILE_N = 128 if is_kernel5 else 64 +WAVE_TILE_M = BLOCK_N * BLOCK_M // WARPS_PER_BLOCK // WAVE_TILE_N +assert BLOCK_N % WAVE_TILE_N == 0, "BN must be a multiple of WN" +assert BLOCK_M % WAVE_TILE_M == 0, "BM must be a multiple of WM" +WAVES_IN_BLOCK_X = BLOCK_N // WAVE_TILE_N +WAVES_IN_BLOCK_Y = BLOCK_M // WAVE_TILE_M +assert WAVES_IN_BLOCK_X * WAVES_IN_BLOCK_Y == WARPS_PER_BLOCK, "wave grid must match warps/block" - # Sanity checks (fail fast if shapes/tiles misalign) - assert BN % WN == 0, "BN must be a multiple of WN" - assert BM % WM == 0, "BM must be a multiple of WM" - nbWaveX = BN // WN - nbWaveY = BM // WM - - assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN" - assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK" - - assert (BN * BK) % BLOCK_SIZE == 0 - assert (BM * BK) % BLOCK_SIZE == 0 +LANES_PER_WAVE_X = 8 +LANES_PER_WAVE_Y = 4 +ITERS_PER_WAVE_N = WAVE_TILE_N // (LANES_PER_WAVE_X * TN) +ITERS_PER_WAVE_M = WAVE_TILE_M // (LANES_PER_WAVE_Y * TM) +N_PER_ITER = WAVE_TILE_N // ITERS_PER_WAVE_N +M_PER_ITER = WAVE_TILE_M // ITERS_PER_WAVE_M +assert WAVE_TILE_N % (LANES_PER_WAVE_X * TN) == 0, "WAVE_TILE_N must be divisible by LANES_PER_WAVE_X*TN" +assert WAVE_TILE_M % (LANES_PER_WAVE_Y * TM) == 0, "WAVE_TILE_M must be divisible by LANES_PER_WAVE_Y*TM" +def hand_spec_kernel3(): # --------------------------- # per-thread read mapping # --------------------------- # A: read BK x BN tiles; B: read BN x BK tiles + tid = UOp.special(THREADS_PER_BLOCK, "lidx0") - threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0") - waveIndex = threadIdx_x // 32 - waveIdx = waveIndex % nbWaveX - waveIdy = waveIndex // nbWaveX - indexInWave = threadIdx_x % 32 + waveIdx = (tid // WARP_SIZE) % WAVES_IN_BLOCK_X + waveIdy = (tid // WARP_SIZE) // WAVES_IN_BLOCK_X + assert waveIdy.vmax+1 == WAVES_IN_BLOCK_Y - nbThreadXPerWave = 8 - nbThreadYPerWave = 4 - - idxInWave = indexInWave % nbThreadXPerWave - idyInWave = indexInWave // nbThreadXPerWave - - nbIterWaveN = WN // (nbThreadXPerWave * TN) - nbIterWaveM = WM // (nbThreadYPerWave * TM) - - SUBWN = WN // nbIterWaveN - SUBWM = WM // nbIterWaveM + idxInWave = (tid % WARP_SIZE) % LANES_PER_WAVE_X + idyInWave = (tid % WARP_SIZE) // LANES_PER_WAVE_X + assert idyInWave.vmax+1 == LANES_PER_WAVE_Y # --------------------------- # block indices & placeholders # --------------------------- - blockIdx_x = UOp.special(N // BN, "gidx0") - blockIdx_y = UOp.special(N // BM, "gidx1") + blockIdx_x = UOp.special(N // BLOCK_N, "gidx0") + blockIdx_y = UOp.special(N // BLOCK_M, "gidx1") a = UOp.placeholder(dtypes.float, (N, N), slot=1) b = UOp.placeholder(dtypes.float, (N, N), slot=2) c = UOp.placeholder(dtypes.float, (N, N), slot=0) - BM_As_stride = (BM + 4) if kernel5 else BM - As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL) - Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL) + BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M + As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL) + Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL) - A_col = UOp.placeholder(dtypes.float, (nbIterWaveM, TM), slot=0, addrspace=AddrSpace.REG) - B_row = UOp.placeholder(dtypes.float, (nbIterWaveN, TN), slot=1, addrspace=AddrSpace.REG) - c_regs = UOp.placeholder(dtypes.float, (nbIterWaveM, TM, nbIterWaveN, TN), slot=2, addrspace=AddrSpace.REG) + A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG) + B_row = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_N, TN), slot=1, addrspace=AddrSpace.REG) + c_regs = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), slot=2, addrspace=AddrSpace.REG) - i = UOp.range(c_regs.dtype.size, 16) + i = UOp.range(c_regs.size, 16) c_regs = c_regs[i].set(0.0, end=i) - kId_range = UOp.range(N // BK, 0) - kId = kId_range * BK + k_tile_range = UOp.range(N // BLOCK_K, 0) # --------------------------- # GLOBAL -> LOCAL (As, Bs) # --------------------------- - nbReadsB = BN * BK // BLOCK_SIZE - i = UOp.range(nbReadsB, 1) - rBIdx = threadIdx_x % BN - rBIdy = threadIdx_x // BN - strideReadB = BLOCK_SIZE // BN - index_x = BN * blockIdx_x + rBIdx - index_y = rBIdy + i * strideReadB + kId - Bs_store = Bs[index_y % BK, index_x % BN].store(b[index_y, index_x]).end(i) + b = b.reshape((N // BLOCK_K, BLOCK_K, + N // BLOCK_N, BLOCK_N)) + i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1) + index_x = tid % BLOCK_N + index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i + Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i) - nbReadsA = BM * BK // BLOCK_SIZE - i = UOp.range(nbReadsA, 2) - rAIdx = threadIdx_x % BK - rAIdy = threadIdx_x // BK - strideReadA = BLOCK_SIZE // BK - index_x = rAIdx + kId - index_y = BM * blockIdx_y + rAIdy + i * strideReadA - As_store = As[index_x % BK, index_y % BM].store(a[index_y, index_x]).end(i) + a = a.reshape((N // BLOCK_M, BLOCK_M, + N // BLOCK_K, BLOCK_K)) + i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2) + index_x = tid % BLOCK_K + index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i + As_store = As[index_x, index_y].store(a[blockIdx_y, index_y, k_tile_range, index_x]).end(i) # TODO: can we automate barrier? barrier = UOp.barrier(As_store, Bs_store) @@ -114,44 +108,45 @@ def hand_spec_kernel3(kernel5=getenv("K5", 0)): As = As.after(barrier) # open inner k range - k = UOp.range(BK, 3) + k = UOp.range(BLOCK_K, 3) # --------------------------- # LOCAL -> REG (per-wave tiles) # --------------------------- - iterWave = UOp.range(nbIterWaveN, 4) + iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4) i = UOp.range(TN, 5) - index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i - B_row = B_row[iterWave, i].set(Bs[k, index], end=(iterWave, i)) + index = waveIdx * WAVE_TILE_N + iterWaveN * N_PER_ITER + idxInWave * TN + i + B_row = B_row[iterWaveN, i].set(Bs[k, index], end=(iterWaveN, i)) - iterWave = UOp.range(nbIterWaveM, 6) + iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6) i = UOp.range(TM, 7) - index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i - A_col = A_col[iterWave, i].set(As[k, index], end=(iterWave, i)) + index = waveIdy * WAVE_TILE_M + iterWaveM * M_PER_ITER + idyInWave * TM + i + A_col = A_col[iterWaveM, i].set(As[k, index], end=(iterWaveM, i)) # --------------------------- # FMA: c_regs += A_col * B_row # --------------------------- - iterWaveM = UOp.range(nbIterWaveM, 8) + iterWaveM = UOp.range(ITERS_PER_WAVE_M, 8) yt = UOp.range(TM, 9) - iterWaveN = UOp.range(nbIterWaveN, 10) + iterWaveN = UOp.range(ITERS_PER_WAVE_N, 10) xt = UOp.range(TN, 12) - c_idx = c_regs.after(k, kId_range)[iterWaveM, yt, iterWaveN, xt] + c_idx = c_regs.after(k, k_tile_range)[iterWaveM, yt, iterWaveN, xt] sink = c_idx.store(c_idx + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt) # Close k, sync, and close K tiles - sink = sink.end(k).barrier().end(kId_range) + sink = sink.end(k).barrier().end(k_tile_range) # --------------------------- # REG -> GLOBAL (epilogue) # --------------------------- - iterWaveM = UOp.range(nbIterWaveM, 1000) + c = c.reshape((N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM, + N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)) + iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000) yt = UOp.range(TM, 1001) - iterWaveN = UOp.range(nbIterWaveN, 1002) + iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002) xt = UOp.range(TN, 1003) - xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave - yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave - sink = c[yOut + yt, xOut + xt].store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt]) + c_glbl_idx = c[blockIdx_y, waveIdy, iterWaveM, idyInWave, yt, blockIdx_x, waveIdx, iterWaveN, idxInWave, xt] + sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt]) sink = sink.end(iterWaveM, iterWaveN, yt, xt) return sink.sink(arg=KernelInfo(opts_to_apply=())) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a7bd23d659..11b62135d3 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -750,9 +750,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): ctx: dict[UOp, str] = {} pm = renderer if pm is None else pm for u in (s:=self.simplify() if simplify else self).toposort(): - # if there is any node in the toposort we can't render, we just render the whole thing using UOp pretty printer - if (u_str:=pm.rewrite(u, ctx=ctx)) is None: return str(s) - ctx[u] = cast(str, u_str) + ctx[u] = cast(str, pm.rewrite(u, ctx=ctx)) return ctx[s] def pyrender(self): return pyrender(self) @@ -1277,6 +1275,7 @@ renderer = PatternMatcher([ (UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])), (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"), + (UPat(GroupOp.All, name="x"), lambda x: str(x)), ]) renderer_infer = PatternMatcher([ diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index c734701cfe..e6bb9fedf0 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -14,9 +14,9 @@ from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", - Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", + **{x:"#f2cb91" for x in GroupOp.Defines}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", + Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",