diff --git a/extra/gemm/amd_matmul.py b/extra/gemm/amd_matmul.py
index 3c4cfbd25c..f7f0c0193e 100644
--- a/extra/gemm/amd_matmul.py
+++ b/extra/gemm/amd_matmul.py
@@ -1,96 +1,39 @@
 # kernel8_batched_gmem.s from https://seb-v.github.io/optimization/update/2025/01/20/Fast-GPU-Matrix-multiplication.html
+# sudo PATH=/opt/homebrew/Cellar/llvm/20.1.6/bin:$PATH AMD_LLVM=0 AMD=1 DEBUG=2 python3 extra/gemm/amd_matmul.py
 import pathlib
-import numpy as np
 from dataclasses import replace
-from tinygrad import Tensor, Device, Context
+from tinygrad import Tensor, Device, Context, GlobalCounters
 from tinygrad.helpers import getenv
-from tinygrad.opt.kernel import Kernel, Opt, OptOps
-from tinygrad.engine.realize import CompiledRunner, ExecItem
-from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp
-
-# TODO: on METAL for `DEBUG=4 python3 extra/gemm/amd_matmul.py`
-# * fix load grouping (like float4). idk why it's not working, need new devectorizer (this is a Monday project)
-# * DONE - remove extra barrier
-# * DONE (moved Ops.ADD) - fix load order to be in order (the +0 one is last!)
-# * explore async (fast) global load -> local store
-# * why is TC=3 broken for 4096x4096?
-# * write syntactic sugar for these local additions + use it in tensor core kernel.py
+from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 
 N = 4096
-LN = 16
 run_count = 5
 
-from tinygrad.shape.shapetracker import ShapeTracker, View
-def transform_load(ctx:tuple[Kernel, set[UOp]], x:UOp):
-  if x.src[0].op is not Ops.DEFINE_GLOBAL: return None
-  if x in ctx[1]: return None
-  print(ctx[0].colored_shape())
-  ctx[1].add(x)
-  input_st: ShapeTracker = x.src[1].arg
-  #strides = input_st.real_strides()
-  #strides = (0,0)+strides[2:]
-  if input_st.real_strides()[2] == 0:
-    perm = (0,1,5,3,4,2)
-    strides = (0,0,LN*4,4,0,0,1,0)
-  elif input_st.real_strides()[3] == 0:
-    perm = (0,1,2,5,4,3)
-    strides = (0,0,LN*4,4,0,0,0,1)
-  else:
-    return None
-  if len(input_st.shape) == 8:
-    local_st = ShapeTracker(views=(View.create((1,1,LN,LN,1,1,4,4), strides),))
-    perm = perm + (6,7)
-  else:
-    local_st = ShapeTracker(views=(View.create((1,1,LN,LN,1,1)),))
-  #local_st = ShapeTracker(views=(View.create((1,1,LN,LN,1,1)),))
-  load_st = local_st.permute(perm)
-  input_st = input_st.permute(perm)
-  lcl = UOp(Ops.DEFINE_LOCAL, x.dtype.ptr(local_st.real_size(), local=True), (), f"temp{x.src[0].arg}")
-  global_load = x.replace(src=(x.src[0], input_st.to_uop()))
-  ret = UOp(Ops.STORE, src=(lcl, local_st.to_uop(), global_load))
-  return UOp(Ops.LOAD, x.dtype, src=(lcl, load_st.to_uop(), ret))
-
-local_loads_pm = PatternMatcher([
-  (UPat(Ops.LOAD, name="x"), transform_load),
-])
-
-def ast_transform(k, ast):
-  #return ast
-  ast = graph_rewrite(ast, local_loads_pm, ctx=(k, set()))
-  #ast = ast.replace(arg=replace(ast.arg, upcasted=0))
-  print(ast)
-  return ast
-
 if __name__ == "__main__":
-  rng = np.random.default_rng()
-  a = Tensor(na:=rng.random((4096, 4096), dtype=np.float32)).realize()
-  b = Tensor(nb:=rng.random((4096, 4096), dtype=np.float32)).realize()
-  c = a @ b
-  si = c.schedule()[-1]
-  k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
-  #opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16),
-  #        Opt(op=OptOps.LOCAL, axis=0, arg=8),
-  #        Opt(op=OptOps.UPCAST, axis=2, arg=4),
-  #        Opt(op=OptOps.UPCAST, axis=1, arg=4),
-  #        Opt(op=OptOps.UPCAST, axis=0, arg=2)]
-  #opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4),
-  #        Opt(op=OptOps.UPCAST, axis=0, arg=4),
-  #        Opt(op=OptOps.LOCAL, axis=1, arg=8),
-  #        Opt(op=OptOps.LOCAL, axis=0, arg=4)]
-  opts = [Opt(op=OptOps.UNROLL, axis=0, arg=LN),
-          #Opt(op=OptOps.UPCAST, axis=0, arg=4),
-          #Opt(op=OptOps.UPCAST, axis=1, arg=4),
-          Opt(op=OptOps.LOCAL, axis=1, arg=LN),
-          Opt(op=OptOps.LOCAL, axis=0, arg=LN)]
-  k.apply_opts(opts)
-  prg = k.to_program(ast_transform=ast_transform)
-  if getenv("FAST", 1) and Device.DEFAULT == "AMD":
-    #src = (pathlib.Path(__file__).parent / "fp32_sgemm_amd" / "src" / "kernel8_batched_gmem.s").read_text()
-    src = (pathlib.Path(__file__).parent / "kernel8_batched_gmem.s").read_text()
-    prg = replace(prg, src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
-    print(prg.global_size, prg.local_size)
-  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
+  ast = (Tensor.empty(N, N)@Tensor.empty(N, N)).schedule()[-1].ast
+  prg = get_program(ast, Device.default.renderer)
+
+  if getenv("ASM") == 1:
+    src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel8_batched_gmem.s").read_text()
+    prgfast = replace(prg, name="kernel", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
+  elif getenv("ASM") == -1:
+    src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel3_registers.cpp").read_text()
+    prgfast = replace(prg, name="kernel3_registers", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
+  else:
+    src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel5_lds_optim.cpp").read_text()
+    prgfast = replace(prg, name="kernel5_lds_optim", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
+  runner = CompiledRunner(prgfast)
+
+  a = Tensor.randn(N, N).realize()
+  b = Tensor.randn(N, N).realize()
+  c = Tensor.zeros(N, N).contiguous().realize()
+
+  GlobalCounters.reset()
+  with Context(DEBUG=2, BEAM=4):
+    for _ in range(run_count): tc = (a@b).realize()
+
+  GlobalCounters.reset()
+  ei = ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer])
   with Context(DEBUG=2):
     for _ in range(run_count): ei.run(wait=True)
-  nc = c.numpy()
-  np.testing.assert_allclose(na@nb, nc, rtol=1e-5)
+  print(f"custom {(c-tc).square().mean().item()}")
diff --git a/extra/gemm/amd_seb/kernel3_registers.cpp b/extra/gemm/amd_seb/kernel3_registers.cpp
new file mode 100644
index 0000000000..4ce5511ece
--- /dev/null
+++ b/extra/gemm/amd_seb/kernel3_registers.cpp
@@ -0,0 +1,141 @@
+typedef long unsigned int size_t;
+extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int);
+struct Dim3 { size_t x, y, z; };
+#define __shared__ __attribute__((shared, aligned(16)))
+__attribute__((device)) inline void __syncthreads() {
+    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
+    __builtin_amdgcn_s_barrier();
+    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
+}
+
+#define BLOCK_SIZE 256
+extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, float *c)
+{
+    constexpr int N = 4096;
+    constexpr float alpha = 1.0;
+    constexpr float beta = 0.0;
+
+    const Dim3 blockIdx{ __ockl_get_group_id(0), __ockl_get_group_id(1), __ockl_get_group_id(2) };
+    const Dim3 threadIdx{ __ockl_get_local_id(0), __ockl_get_local_id(1), __ockl_get_local_id(2) };
+
+    // Block Tile size
+    constexpr int BN = 128;
+    constexpr int BM = 128;
+    // Number of Row or column we read per batch
+    constexpr int BK = 8;
+
+    // Thread Tile size
+    constexpr int TN = 4;
+    constexpr int TM = 4;
+
+    constexpr int nbWaves = BLOCK_SIZE / 32;
+    // Wave Tile size
+    constexpr int WN = 64;
+    constexpr int WM = BN * BM / nbWaves / WN;
+
+    // Number of wave on X & Y axis in the Block tile
+    constexpr int nbWaveX = BN / WN;
+    constexpr int nbWaveY = BM / WM;
+
+    const int waveIndex = threadIdx.x / 32;
+    const int waveIdx = waveIndex % nbWaveX;
+    const int waveIdy = waveIndex / nbWaveX;
+    const int indexInWave = threadIdx.x % 32;
+
+    // A wave is a block of 8x4 of the output matrix
+    constexpr int nbThreadXPerWave = 8;
+    constexpr int nbThreadYPerWave = 4;
+
+    // Thread coordinates in Wave
+    const int idxInWave = indexInWave % nbThreadXPerWave;
+    const int idyInWave = indexInWave / nbThreadXPerWave;
+
+    constexpr int nbIterWaveN = WN / (nbThreadXPerWave * TN);
+    constexpr int nbIterWaveM = WM / (nbThreadYPerWave * TM);
+
+    // Wave Sub-tile size
+    constexpr int SUBWN = WN / nbIterWaveN;
+    constexpr int SUBWM = WM / nbIterWaveM;
+
+    // Thread mapping to read BKxBN block from A
+    int rAIdx = threadIdx.x % BK;
+    int rAIdy = threadIdx.x / BK;
+    // Thread mapping to read BNxBK block from B
+    int rBIdx = threadIdx.x % BN;
+    int rBIdy = threadIdx.x / BN;
+
+    constexpr int strideReadB = BLOCK_SIZE / BN;
+    constexpr int strideReadA = BLOCK_SIZE / BK;
+    constexpr int nbReadsB = BN * BK / BLOCK_SIZE;
+    constexpr int nbReadsA = BM * BK / BLOCK_SIZE;
+
+    float A_col[nbIterWaveM * TM];
+    float B_row[nbIterWaveN * TN];
+
+    __shared__ float As[BK][BM];
+    __shared__ float Bs[BK][BN];
+
+    float c_regs[TM * nbIterWaveM * TN * nbIterWaveN] = {0.0f};
+
+    // Iteration over BK blocks.
+    for (int kId = 0; kId < N; kId += BK) {
+        // We populate the Shared Memory with Ks row and columns
+        for (int i = 0; i < nbReadsB; i++) {
+            int index_x = BN * blockIdx.x + rBIdx;
+            int index_y = rBIdy + i * strideReadB + kId;
+            Bs[index_y % BK][index_x % BN] = b[N * index_y + index_x];
+        }
+
+        for (int i = 0; i < nbReadsA; i++) {
+            int index_x = rAIdx + kId;
+            int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+            As[(index_x % BK)][(index_y % BM)] = a[N * index_y + index_x];
+        }
+
+        __syncthreads();
+        for (int k = 0; k < BK; k++) {
+            // we cache A & B for the entire Wave tile
+            for (int iterWave = 0; iterWave < nbIterWaveN; iterWave++) {
+                for (int i = 0; i < TN; i++) {
+                    int index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i;
+                    B_row[iterWave * TN + i] = Bs[k][index];
+                }
+            }
+
+            for (int iterWave = 0; iterWave < nbIterWaveM; iterWave++) {
+                for (int i = 0; i < TM; i++) {
+                    int index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i;
+                    A_col[iterWave * TM + i] = As[k][index];
+                }
+            }
+
+            // we accumulate to C_regs
+            for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
+                for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
+                    for (int yt = 0; yt < TM; yt++) {
+                        for (int xt = 0; xt < TN; xt++) {
+                            const int x = iterWaveN * TN + xt;
+                            const int y = iterWaveM * TM + yt;
+                            c_regs[y * TN * nbIterWaveN + x] += A_col[y] * B_row[x];
+                        }
+                    }
+                }
+            }
+        }
+        __syncthreads();
+    }
+
+    for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
+        for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
+            int xOut = blockIdx.x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave;
+            int yOut = blockIdx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave;
+            for (int yt = 0; yt < TM; yt++) {
+                for (int xt = 0; xt < TN; xt++) {
+                    int indexC = N * (yOut + yt) + xOut + xt;
+                    c[indexC] = beta * c[indexC] + alpha * c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)];
+                }
+            }
+        }
+    }
+}
diff --git a/extra/gemm/amd_seb/kernel5_lds_optim.cpp b/extra/gemm/amd_seb/kernel5_lds_optim.cpp
new file mode 100644
index 0000000000..965f3d54d2
--- /dev/null
+++ b/extra/gemm/amd_seb/kernel5_lds_optim.cpp
@@ -0,0 +1,172 @@
+typedef long unsigned int size_t;
+extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int);
+struct Dim3 { size_t x, y, z; };
+#define __shared__ __attribute__((shared, aligned(16)))
+__attribute__((device)) inline void __syncthreads() {
+    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
+    __builtin_amdgcn_s_barrier();
+    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
+}
+
+#define BLOCK_SIZE 128
+extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, BLOCK_SIZE)))
+kernel5_lds_optim(float *a, float *b, float *c)
+{
+    constexpr int N = 4096;
+    constexpr float alpha = 1.0;
+    constexpr float beta = 0.0;
+
+    const Dim3 blockIdx{ __ockl_get_group_id(0), __ockl_get_group_id(1), __ockl_get_group_id(2) };
+    const Dim3 threadIdx{ __ockl_get_local_id(0), __ockl_get_local_id(1), __ockl_get_local_id(2) };
+
+    // Block Tile size
+    constexpr int BN = 128;
+    constexpr int BM = 128;
+    // Number of Row or column we read per batch
+    constexpr int BK = 8;
+
+    // Thread Tile size . 4x4
+    constexpr int TN = 4;
+    constexpr int TM = 4;
+
+    constexpr int nbWaves = BLOCK_SIZE / 32;
+    // Wave Tile size
+    constexpr int WN = 128;
+    constexpr int WM = BN * BM / nbWaves / WN;
+
+    // Number of wave on X & Y axis in the Block tile
+    constexpr int nbWaveX = BN / WN;
+    constexpr int nbWaveY = BM / WM;
+
+    const int waveIndex = threadIdx.x / 32;
+    const int waveIdx = waveIndex % nbWaveX;
+    const int waveIdy = waveIndex / nbWaveX;
+    const int indexInWave = threadIdx.x % 32;
+
+    // A wave is a block of 8x4 of the output matrix
+    constexpr int nbThreadXPerWave = 8;
+    constexpr int nbThreadYPerWave = 4;
+
+    // Thread coordinates in Wave
+    const int idxInWave = indexInWave % nbThreadXPerWave;
+    const int idyInWave = indexInWave / nbThreadXPerWave;
+
+    constexpr int nbIterWaveN = WN / (nbThreadXPerWave * TN);
+    constexpr int nbIterWaveM = WM / (nbThreadYPerWave * TM);
+
+    // Wave Sub-tile size
+    constexpr int SUBWN = WN / nbIterWaveN;
+    constexpr int SUBWM = WM / nbIterWaveM;
+
+    // Thread mapping to read BKxBN block from A
+    int rAIdx = threadIdx.x % BK;
+    int rAIdy = threadIdx.x / BK;
+    // Thread mapping to read BNxBK block from B
+    int rBIdx = threadIdx.x % BN;
+    int rBIdy = threadIdx.x / BN;
+
+    constexpr int strideReadB = BLOCK_SIZE / BN;
+    constexpr int strideReadA = BLOCK_SIZE / BK;
+    constexpr int nbReadsB = BN * BK / BLOCK_SIZE;
+    constexpr int nbReadsA = BM * BK / BLOCK_SIZE;
+
+    float A_col[nbIterWaveM * TM];
+    float B_row[nbIterWaveN * TN];
+
+    __shared__ float As[BK][BM+4]; // 4 padding to avoid bank conflicts
+    __shared__ float Bs[BK][BN];
+
+    float c_regs[TM * nbIterWaveM * TN * nbIterWaveN] = {0.0f};
+
+    // initial copy into shared memory
+    for (int i = 0; i < nbReadsB; i++) {
+        int index_x = BN * blockIdx.x + rBIdx;
+        int index_y = rBIdy + i * strideReadB;
+        Bs[index_y % BK][index_x % BN] = b[N * index_y + index_x];
+    }
+    for (int i = 0; i < nbReadsA; i++) {
+        int index_x = rAIdx;
+        int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+        As[(index_x % BK)][(index_y % BM)] = a[N * index_y + index_x];
+    }
+
+    __syncthreads();
+    // Iteration over BK blocks.
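+    // Double buffering: while the BK-slice already resident in LDS is consumed,
+    // the next BK-slice is prefetched from global memory into registers (regA/regB)
+    // and only written to LDS after the barrier at the end of the iteration.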
+    for (int kId = 0; kId < N; kId += BK) {
+        float regA[nbReadsA];
+        float regB[nbReadsB];
+        if (kId < N - BK) {
+            // We populate the Shared Memory with Ks row and columns
+            for (int i = 0; i < nbReadsB; i++) {
+                int index_x = BN * blockIdx.x + rBIdx;
+                int index_y = rBIdy + i * strideReadB + kId + BK;
+                regB[i] = b[N * index_y + index_x];
+            }
+
+            for (int i = 0; i < nbReadsA; i++) {
+                int index_x = rAIdx + kId + BK;
+                int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+                regA[i] = a[N * index_y + index_x];
+            }
+        }
+
+        for (int k = 0; k < BK; k++) {
+            // we cache A & B for the entire Wave tile
+            for (int iterWave = 0; iterWave < nbIterWaveN; iterWave++) {
+                for (int i = 0; i < TN; i++) {
+                    int index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i;
+                    B_row[iterWave * TN + i] = Bs[k][index];
+                }
+            }
+
+            for (int iterWave = 0; iterWave < nbIterWaveM; iterWave++) {
+                for (int i = 0; i < TM; i++) {
+                    int index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i;
+                    A_col[iterWave * TM + i] = As[k][index];
+                }
+            }
+
+            // we accumulate to C_regs
+            for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
+                for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
+                    for (int yt = 0; yt < TM; yt++) {
+                        for (int xt = 0; xt < TN; xt++) {
+                            const int x = iterWaveN * TN + xt;
+                            const int y = iterWaveM * TM + yt;
+                            c_regs[y * TN * nbIterWaveN + x] += A_col[y] * B_row[x];
+                        }
+                    }
+                }
+            }
+        }
+        __syncthreads();
+        if (kId < N - BK) {
+            for (int i = 0; i < nbReadsB; i++) {
+                int index_x = BN * blockIdx.x + rBIdx;
+                int index_y = rBIdy + i * strideReadB + kId + BK;
+                Bs[index_y % BK][index_x % BN] = regB[i]; // row
+            }
+
+            for (int i = 0; i < nbReadsA; i++) {
+                int index_x = rAIdx + kId + BK;
+                int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+                As[(index_x % BK)][(index_y % BM)] = regA[i];
+            }
+            __syncthreads();
+        }
+    }
+
+    for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
+        for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
+            int xOut = blockIdx.x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave;
+            int yOut = blockIdx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave;
+            for (int yt = 0; yt < TM; yt++) {
+                for (int xt = 0; xt < TN; xt++) {
+                    int indexC = N * (yOut + yt) + xOut + xt;
+                    c[indexC] = beta * c[indexC] + alpha * c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)];
+                }
+            }
+        }
+    }
+}
diff --git a/extra/gemm/kernel8_batched_gmem.s b/extra/gemm/amd_seb/kernel8_batched_gmem.s
similarity index 99%
rename from extra/gemm/kernel8_batched_gmem.s
rename to extra/gemm/amd_seb/kernel8_batched_gmem.s
index b37c050d74..00c321d5e7 100644
--- a/extra/gemm/kernel8_batched_gmem.s
+++ b/extra/gemm/amd_seb/kernel8_batched_gmem.s
@@ -1,6 +1,5 @@
 .text
 .amdgcn_target "amdgcn-amd-amdhsa--gfx1100"
-        ;.amdhsa_code_object_version 5
 .protected kernel ; -- Begin function kernel
 .globl kernel
 .p2align 8
@@ -9,7 +8,7 @@ kernel: ; @kernel
 ; %bb.0: ; %.preheader193
 ;; Init code for matrix A and B buffer Loads - START
-        s_load_b128 s[20:23], s[0:1], 0x8 ; Matrix A and B
+        s_load_b128 s[20:23], s[0:1], 0x0 ; Matrix A and B
         s_waitcnt lgkmcnt(0)
 
         ; Matrix B offsets:
@@ -76,14 +75,12 @@ kernel: ; @kernel
         s_clause 0x1
-        ;s_load_b128 s[4:7], s[0:1], 0x18 ; N, alpha, beta, ???
-        s_load_b128 s[8:11], s[0:1], 0x8 ; Matrix A and B
-
-        s_mov_b32 s4, 4096 ; hardcode 4096
-        s_mov_b32 s5, 0x3f800000 ; alpha
-        s_mov_b32 s6, 0 ; beta
-        s_mov_b32 s7, 0
-
+        ; s_load_b128 s[4:7], s[0:1], 0x18
+        ; N=4096, alpha=1.0, beta=0.0
+        s_mov_b32 s4, 4096
+        s_mov_b32 s5, 0x3F800000
+        s_mov_b32 s6, 0
+        s_load_b128 s[8:11], s[0:1], 0x0
         s_lshl_b32 s2, s14, 7
         v_lshrrev_b32_e32 v4, 3, v0
         v_or_b32_e32 v1, s2, v0
@@ -93,7 +90,7 @@ kernel: ; @kernel
         v_or_b32_e32 v22, s3, v4
         v_ashrrev_i32_e32 v2, 31, v1
         s_lshr_b32 s12, s12, 25
-        s_load_b64 s[0:1], s[0:1], 0 ; Matrix C
+        s_load_b64 s[0:1], s[0:1], 0x10
         v_lshlrev_b32_e32 v135, 2, v118
         s_delay_alu instid0(VALU_DEP_2) | instskip(SKIP_3) | instid1(VALU_DEP_3)
         v_lshlrev_b64 v[5:6], 2, v[1:2]
@@ -463,7 +460,7 @@ kernel: ; @kernel
         v_mov_b32_e32 v5, 0
         v_mov_b32_e32 v3, 0
-        s_add_i32 s7, s4, -1
+        s_add_i32 s7, s4, -8
         s_add_u32 s8, s8, 32
         s_addc_u32 s9, s9, 0
         s_mov_b32 s12, 0
@@ -2398,18 +2395,9 @@ amdhsa.kernels:
         .offset:         16
         .size:           8
         .value_kind:     global_buffer
-      - .offset:         24
-        .size:           4
-        .value_kind:     by_value
-      - .offset:         28
-        .size:           4
-        .value_kind:     by_value
-      - .offset:         32
-        .size:           4
-        .value_kind:     by_value
     .group_segment_fixed_size: 8320
     .kernarg_segment_align: 8
-    .kernarg_segment_size: 36
+    .kernarg_segment_size: 24
     .language:       OpenCL C
     .language_version:
       - 2
diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py
new file mode 100644
index 0000000000..0833dcb112
--- /dev/null
+++ b/extra/gemm/amd_uop_matmul.py
@@ -0,0 +1,167 @@
+from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
+from tinygrad.helpers import prod, unwrap
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.opt.kernel import AxisType
+from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
+from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp, GroupOp
+from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
+from tinygrad.kernelize.kernelize import merge_views
+from tinygrad.shape.view import View
+
+N = 4096
+run_count = 5
+
+# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
+# src->r->view --> src->view->r
+def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
+  if r.tag is not None: return None
+  # confirm the input is in order
+  # TODO: replace this with a UOp that allows for nothing else then remove this
+  permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
+  assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
+
+  # append the reduce shape to each of the views
+  reduce_count = len(r.axis_arg)
+  prshape = prod(rshape:=src.shape[-reduce_count:])
+  rstrides = strides_for_shape(rshape)
+  nv = [View.create(v.shape[:-reduce_count]+rshape, tuple(x*prshape for x in v.strides[:-reduce_count])+rstrides, v.offset*prshape,
+                    v.mask[:-reduce_count]+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
+
+  # no reshape required with shrinking REDUCE_AXIS
+  return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
+             (r.arg[0], tuple(range(len(view.shape)-reduce_count, len(view.shape)))))
+
+early_view_left = merge_views+PatternMatcher([
+  # view before elementwise and buffer ops
+  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.VALID, Ops.STORE, Ops.LOAD}, name="e"),), name="view"),
+   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src)) if e.tag is None else None),
+  # push a non contiguous ShapeTracker through reduceop
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
+])
+
+def hand_spec():
+  # Block Tile size . 128x128
+  # Thread Tile size . 4x4
+  # Wave Tile size . 128x32
+  # A wave is . 8x4
+  # ────── problem size and tiling params (mirror the C kernel) ───────────────────
+  BK = 8                 # depth of K-tile
+  BN = BM = 128          # block-tile (output) sizes
+  # the real thread is 16x8 = 128 regs
+  TM = 4
+  nbIterWaveM = 2
+  TN = 4
+  nbIterWaveN = 4
+
+  # ────── shared-memory tile sizes (unchanged) ───────────────────────────────────
+  LDS_A_SZ = BK * BM     # 1024 floats
+  LDS_B_SZ = BK * BN     # 1024 floats
+
+  bC = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)  # output C
+  bA = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)  # input A
+  bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)  # input B
+
+  # TODO: this should not be a string, just a number
+  lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, local=True), arg="As")
+  lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, local=True), arg="Bs")
+
+  s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1))
+  s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N))
+  s2 = ShapeTracker.from_shape((N, N, 1), (N, 1, 0))
+
+  ls0 = ShapeTracker.from_shape((BM, BK))
+  ls1 = ShapeTracker.from_shape((BN, BK))
+
+  buf_at = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
+  buf_bt = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
+  axis_types = buf_at + buf_bt + [AxisType.REDUCE, AxisType.UNROLL, AxisType.UNROLL, AxisType.UNROLL]
+
+  # 128 x 128 x 8
+  full_shape = (N//BM, 2, 2, 2, 2, 2, 2, 2, N//BN, 2, 2, 2, 2, 2, 2, 2, N//BK, 2, 2, 2)
+
+  s0 = s0.reshape(full_shape)
+  s1 = s1.reshape(full_shape)
+  s2 = s2.reshape(full_shape[:-4] + (1,)*4)
+
+  ls0 = ls0.reshape((1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2)).expand(s0.shape)
+  ls1 = ls1.reshape((1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2)).expand(s1.shape)
+  assert ls0.real_size() == LDS_A_SZ
+  assert ls1.real_size() == LDS_B_SZ
+
+  # BK is a loop of 8
+  # each loop reads 8 in A, 16 in B
+
+  print(ls0)
+  print(ls1)
+
+  permaxis = []
+  for axis_order in [AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST, AxisType.GROUP_REDUCE, AxisType.REDUCE, AxisType.UNROLL]:
+    permaxis += [i for i,a in enumerate(axis_types) if a == axis_order]
+  axis_types = [axis_types[x] for x in permaxis]
+  s0, s1, s2, ls0, ls1 = [x.permute(tuple(permaxis)) for x in [s0, s1, s2, ls0, ls1]]
+  print(axis_types)
+
+  lw0, lr0 = ls0, ls0
+  lw1, lr1 = ls1, ls1
+
+  # first round of permutes
+
+  permaxis = (0, 1, 19, 18, 17, 12, 11, 10, 5, 4, 3, 2, 6, 7, 8, 9, 16, 13, 14, 15)
+  s0 = s0.permute(permaxis)
+  lw0 = lw0.permute(permaxis)
+
+  permaxis = (0, 1, 15, 14, 9, 8, 7, 6, 13, 19, 18, 17, 5, 4, 3, 2, 16, 12, 11, 10)
+  s1 = s1.permute(permaxis)
+  lw1 = lw1.permute(permaxis)
+
+  # second round of permutes
+  #permaxis = (0, 1, 12, 11, 5, 4, 3, 2, 10, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 19)
+  #lw0 = lw0.permute(permaxis)
+  #lr0 = lr0.permute(permaxis)
+
+  from tinygrad.opt.kernel import axis_colors, colored
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s0.shape, s0.views[0].strides, axis_types)]))
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s1.shape, s1.views[0].strides, axis_types)]))
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s2.shape, s2.views[0].strides, axis_types)]))
+  print("lw")
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw0.shape, lw0.views[0].strides, axis_types)]))
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw1.shape, lw1.views[0].strides, axis_types)]))
+  print("lr")
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr0.shape, lr0.views[0].strides, axis_types)]))
+  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr1.shape, lr1.views[0].strides, axis_types)]))
+
+  # loads and stores
+  bs0 = bA.view(s0).load()
+  bs1 = bB.view(s1).load()
+  bs0 = lAs.view(lr0).load(lAs.view(lw0).store(bs0))
+  bs1 = lBs.view(lr1).load(lBs.view(lw1).store(bs1))
+
+  mat = (bs0 * bs1).r(Ops.ADD, tuple([i for i,a in enumerate(axis_types) if a in (AxisType.REDUCE, AxisType.UNROLL)]), permute=False)
+  st = bC.view(s2).store(mat)
+
+  ast = st.sink(arg=KernelInfo(axis_types=tuple(axis_types), name="tinygemm"))
+  ast = graph_rewrite(ast, merge_views)
+  prg = get_program(ast, Device.default.renderer)
+  print(prg.src)
+  return prg
+
+
+if __name__ == "__main__":
+  hprg = hand_spec()
+  hrunner = CompiledRunner(hprg)
+
+  a = Tensor.randn(N, N).realize()
+  b = Tensor.randn(N, N).realize()
+  hc = Tensor.zeros(N, N).contiguous().realize()
+
+  GlobalCounters.reset()
+  with Context(DEBUG=2, BEAM=4):
+    for _ in range(run_count): tc = (a@b).realize()
+
+  GlobalCounters.reset()
+  ei = ExecItem(hrunner, [hc.uop.buffer, a.uop.buffer, b.uop.buffer])
+  with Context(DEBUG=2):
+    for _ in range(run_count): ei.run(wait=True)
+  err = (hc-tc).square().mean().item()
+  print(f"hrunner {err}")
+  assert err < 1e-06
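
Note (not part of the patch): a minimal standalone Python sketch that re-derives the tiling constants used by kernel5_lds_optim.cpp, which hand_spec() above mirrors (BK=8, BN=BM=128, TM=TN=4, nbIterWaveM=2, nbIterWaveN=4). The names follow the C++ kernel; the LDS padding check applies to the C++ kernel only, since hand_spec() allocates unpadded BKxBM / BKxBN tiles.

# sanity-check the wave/thread-tile arithmetic of kernel5_lds_optim.cpp
BLOCK_SIZE, BN, BM, BK, TN, TM, WN = 128, 128, 128, 8, 4, 4, 128
nbThreadXPerWave, nbThreadYPerWave = 8, 4
nbWaves = BLOCK_SIZE // 32                    # 4 waves per 128-thread workgroup
WM = BN * BM // nbWaves // WN                 # 32 -> wave tile is 128x32
nbIterWaveN = WN // (nbThreadXPerWave * TN)   # 4
nbIterWaveM = WM // (nbThreadYPerWave * TM)   # 2
assert (WM, nbIterWaveN, nbIterWaveM) == (32, 4, 2)
# each thread accumulates nbIterWaveM*TM x nbIterWaveN*TN = 8x16 = 128 floats in c_regs
assert nbIterWaveM * TM * nbIterWaveN * TN == 128
# LDS per workgroup in kernel5: As carries 4 padding columns to avoid bank conflicts
assert (BK * (BM + 4) + BK * BN) * 4 == 8320  # bytes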