diff --git a/extra/gemm/amd_matmul.py b/extra/gemm/amd_matmul.py
index 97c081aa77..6d704766b8 100644
--- a/extra/gemm/amd_matmul.py
+++ b/extra/gemm/amd_matmul.py
@@ -19,6 +19,9 @@ if __name__ == "__main__":
   elif getenv("ASM") == -1:
     src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel3_registers.cpp").read_text()
     prgfast = replace(prg, name="kernel3_registers", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
+  elif getenv("ASM") == -2:
+    src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel4_gmem_df.cpp").read_text()
+    prgfast = replace(prg, name="kernel4_gmem_db", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
   else:
     src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel5_lds_optim.cpp").read_text()
     prgfast = replace(prg, name="kernel5_lds_optim", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
diff --git a/extra/gemm/amd_seb/kernel3_registers.cpp b/extra/gemm/amd_seb/kernel3_registers.cpp
index 85bb62b86b..f4cdf21e29 100644
--- a/extra/gemm/amd_seb/kernel3_registers.cpp
+++ b/extra/gemm/amd_seb/kernel3_registers.cpp
@@ -10,7 +10,8 @@ __attribute__((device)) inline void __syncthreads() {
 }
 
 #define BLOCK_SIZE 256
-extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, float *c)
+extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, BLOCK_SIZE)))
+kernel3_registers(float *a, float *b, float *c)
 {
     constexpr int N = 4096;
     constexpr float alpha = 1.0;
diff --git a/extra/gemm/amd_seb/kernel4_gmem_df.cpp b/extra/gemm/amd_seb/kernel4_gmem_df.cpp
new file mode 100644
index 0000000000..1d12654313
--- /dev/null
+++ b/extra/gemm/amd_seb/kernel4_gmem_df.cpp
@@ -0,0 +1,172 @@
+typedef long unsigned int size_t;
+extern "C" __attribute__((device, const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int);
+struct Dim3 { size_t x, y, z; };
+#define __shared__ __attribute__((shared, aligned(16)))
+__attribute__((device)) inline void __syncthreads() {
+    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");
+    __builtin_amdgcn_s_barrier();
+    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
+}
+
+#define BLOCK_SIZE 256
+extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, BLOCK_SIZE)))
+kernel4_gmem_db(float *a, float *b, float *c)
+{
+    constexpr int N = 4096;
+    constexpr float alpha = 1.0;
+    constexpr float beta = 0.0;
+
+    const Dim3 blockIdx{ __ockl_get_group_id(0), __ockl_get_group_id(1), __ockl_get_group_id(2) };
+    const Dim3 threadIdx{ __ockl_get_local_id(0), __ockl_get_local_id(1), __ockl_get_local_id(2) };
+
+    // Block Tile size
+    constexpr int BN = 128;
+    constexpr int BM = 128;
+    // Number of Row or column we read per batch
+    constexpr int BK = 8;
+
+    // Thread Tile size
+    constexpr int TN = 4;
+    constexpr int TM = 4;
+
+    constexpr int nbWaves = BLOCK_SIZE / 32;
+    // Wave Tile size
+    constexpr int WN = 64;
+    constexpr int WM = BN * BM / nbWaves / WN;
+
+    // Number of wave on X & Y axis in the Block tile
+    constexpr int nbWaveX = BN / WN;
+    constexpr int nbWaveY = BM / WM;
+
+    const int waveIndex = threadIdx.x / 32;
+    const int waveIdx = waveIndex % nbWaveX;
+    const int waveIdy = waveIndex / nbWaveX;
+    const int indexInWave = threadIdx.x % 32;
+
+    // A wave is a block of 8x4 of the output matrix
+    constexpr int nbThreadXPerWave = 8;
+    constexpr int nbThreadYPerWave = 4;
+
+    // Thread coordinates in Wave
+    const int idxInWave = indexInWave % nbThreadXPerWave;
+    const int idyInWave = indexInWave / nbThreadXPerWave;
+
+    constexpr int nbIterWaveN = WN / (nbThreadXPerWave * TN);
+    constexpr int nbIterWaveM = WM / (nbThreadYPerWave * TM);
+
+    // Wave Sub-tile size
+    constexpr int SUBWN = WN / nbIterWaveN;
+    constexpr int SUBWM = WM / nbIterWaveM;
+
+    // Thread mapping to read BKxBN block from A
+    int rAIdx = threadIdx.x % BK;
+    int rAIdy = threadIdx.x / BK;
+    // Thread mapping to read BNxBK block from B
+    int rBIdx = threadIdx.x % BN;
+    int rBIdy = threadIdx.x / BN;
+
+    constexpr int strideReadB = BLOCK_SIZE / BN;
+    constexpr int strideReadA = BLOCK_SIZE / BK;
+    constexpr int nbReadsB = BN * BK / BLOCK_SIZE;
+    constexpr int nbReadsA = BM * BK / BLOCK_SIZE;
+
+    float A_col[nbIterWaveM * TM];
+    float B_row[nbIterWaveN * TN];
+
+    __shared__ float As[BK][BM];
+    __shared__ float Bs[BK][BN];
+
+    float c_regs[TM * nbIterWaveM * TN * nbIterWaveN] = {0.0f};
+
+    for (int i = 0; i < nbReadsB; i++) {
+        int index_x = BN * blockIdx.x + rBIdx;
+        int index_y = rBIdy + i * strideReadB;
+        Bs[index_y % BK][index_x % BN] = b[N * index_y + index_x];
+    }
+
+    for (int i = 0; i < nbReadsA; i++) {
+        int index_x = rAIdx;
+        int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+        As[(index_x % BK)][(index_y % BM)] = a[N * index_y + index_x];
+    }
+
+    __syncthreads();
+    // Iteration over BK blocks.
+    for (int kId = 0; kId < N; kId += BK) {
+        float regA[nbReadsA];
+        float regB[nbReadsB];
+        if (kId < N - BK) {
+            // We populate the Shared Memory with Ks row and columns
+            for (int i = 0; i < nbReadsB; i++) {
+                int index_x = BN * blockIdx.x + rBIdx;
+                int index_y = rBIdy + i * strideReadB + kId + BK;
+                regB[i] = b[N * index_y + index_x];
+            }
+
+            for (int i = 0; i < nbReadsA; i++) {
+                int index_x = rAIdx + kId + BK;
+                int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+                regA[i] = a[N * index_y + index_x];
+            }
+        }
+
+        for (int k = 0; k < BK; k++) {
+            // we cache A & B for the entire Wave tile
+            for (int iterWave = 0; iterWave < nbIterWaveN; iterWave++) {
+                for (int i = 0; i < TN; i++) {
+                    int index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i;
+                    B_row[iterWave * TN + i] = Bs[k][index];
+                }
+            }
+
+            for (int iterWave = 0; iterWave < nbIterWaveM; iterWave++) {
+                for (int i = 0; i < TM; i++) {
+                    int index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i;
+                    A_col[iterWave * TM + i] = As[k][index];
+                }
+            }
+
+            // we accumulate to C_regs
+            for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
+                for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
+                    for (int yt = 0; yt < TM; yt++) {
+                        for (int xt = 0; xt < TN; xt++) {
+                            const int x = iterWaveN * TN + xt;
+                            const int y = iterWaveM * TM + yt;
+                            c_regs[y * TN * nbIterWaveN + x] += A_col[y] * B_row[x];
+                        }
+                    }
+                }
+            }
+        }
+        __syncthreads();
+        if (kId < N - BK) {
+            for (int i = 0; i < nbReadsB; i++) {
+                int index_x = BN * blockIdx.x + rBIdx;
+                int index_y = rBIdy + i * strideReadB + kId + BK;
+                Bs[index_y % BK][index_x % BN] = regB[i]; // row
+            }
+
+            for (int i = 0; i < nbReadsA; i++) {
+                int index_x = rAIdx + kId + BK;
+                int index_y = BM * blockIdx.y + rAIdy + i * strideReadA;
+                As[(index_x % BK)][(index_y % BM)] = regA[i];
+            }
+            __syncthreads();
+        }
+    }
+
+    for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
+        for (int iterWaveN = 0; iterWaveN < nbIterWaveN; iterWaveN++) {
+            int xOut = blockIdx.x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave;
+            int yOut = blockIdx.y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave;
+            for (int yt = 0; yt < TM; yt++) {
+                for (int xt = 0; xt < TN; xt++) {
+                    int indexC = N * (yOut + yt) + xOut + xt;
+                    c[indexC] = beta * c[indexC] + alpha * c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)];
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/extra/gemm/amd_seb/kernel5_lds_optim.cpp b/extra/gemm/amd_seb/kernel5_lds_optim.cpp
index 965f3d54d2..bf5b34e9c2 100644
--- a/extra/gemm/amd_seb/kernel5_lds_optim.cpp
+++ b/extra/gemm/amd_seb/kernel5_lds_optim.cpp
@@ -26,7 +26,7 @@ kernel5_lds_optim(float *a, float *b, float *c)
     // Number of Row or column we read per batch
     constexpr int BK = 8;
 
-    // Thread Tile size . 4x4
+    // Thread Tile size
     constexpr int TN = 4;
     constexpr int TM = 4;
 
diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py
index 6cbc85c478..22246c31f4 100644
--- a/extra/gemm/amd_uop_matmul.py
+++ b/extra/gemm/amd_uop_matmul.py
@@ -91,11 +91,11 @@ def hl_spec_kernel3():
   sink = graph_rewrite(sink, merge_views)
   return sink
 
-def hand_spec_kernel3():
-  BLOCK_SIZE = 256
+def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
+  BLOCK_SIZE = 128 if kernel5 else 256
   nbWaves = BLOCK_SIZE // 32
 
-  WN = 64
+  WN = 128 if kernel5 else 64
   WM = BN * BM // nbWaves // WN
 
   nbWaveX = BN // WN
@@ -141,7 +141,8 @@ def hand_spec_kernel3():
   A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
   B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
 
-  As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
+  BM_As_stride = (BM+4) if kernel5 else BM
+  As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM_As_stride, AddrSpace.LOCAL), arg=0)
   Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
   c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
@@ -149,51 +150,131 @@ def hand_spec_kernel3():
   i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
   init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
 
-  kId_range = UOp.range(dtypes.int, N//BK, 0)
-  kId = kId_range*BK
+  if kernel4:
+    regA = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsA, AddrSpace.REG), arg=3)
+    regB = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsB, AddrSpace.REG), arg=4)
 
-  # load from globals into locals
-  i = UOp.range(dtypes.int, nbReadsB, 1)
-  index_x = BN * blockIdx_x + rBIdx
-  index_y = rBIdy + i * strideReadB + kId
-  Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
+    # initial load from globals into locals (0)
+    kId = 0
 
-  i = UOp.range(dtypes.int, nbReadsA, 2)
-  index_x = rAIdx + kId
-  index_y = BM * blockIdx_y + rAIdy + i * strideReadA
-  As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
+    # load from globals into locals
+    i = UOp.range(dtypes.int, nbReadsB, 0)
+    index_x = BN * blockIdx_x + rBIdx
+    index_y = rBIdy + i * strideReadB + kId
+    Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
 
-  barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
+    i = UOp.range(dtypes.int, nbReadsA, 1)
+    index_x = rAIdx + kId
+    index_y = BM * blockIdx_y + rAIdy + i * strideReadA
+    As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
 
-  k = UOp.range(dtypes.int, BK, 3)
+    # iterate over the middle chunk
+    kId_range = UOp.range(dtypes.int, N//BK-1, 2)
+    kId = kId_range*BK
 
-  # load from locals into registers
-  iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
-  i = UOp.range(dtypes.int, TN, 5)
-  index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
-  B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
+    barrier = UOp.barrier(As_store, Bs_store)
 
-  iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
-  i = UOp.range(dtypes.int, TM, 7)
-  index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
-  A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
+    # load from globals into registers (next round)
+    i = UOp.range(dtypes.int, nbReadsB, 3)
+    index_x = BN * blockIdx_x + rBIdx
+    index_y = rBIdy + i * strideReadB + kId + BK
+    regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
 
-  # do the GEMM math
-  iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
-  iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 9)
-  yt = UOp.range(dtypes.int, TM, 10)
-  xt = UOp.range(dtypes.int, TN, 11)
-  x = iterWaveN * TN + xt
-  y = iterWaveM * TM + yt
-  c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
-  sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
-                          iterWaveM, iterWaveN, yt, xt, k, kId_range)
+    i = UOp.range(dtypes.int, nbReadsA, 4)
+    index_x = rAIdx + kId + BK
+    index_y = BM * blockIdx_y + rAIdy + i * strideReadA
+    regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
+
+    def inner_loop(first_range, inp_dep=()):
+      # inner unroll
+      k = UOp.range(dtypes.int, BK, first_range+0)
+
+      # load from locals into registers
+      iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1)
+      i = UOp.range(dtypes.int, TN, first_range+2)
+      index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
+      B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
+
+      iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3)
+      i = UOp.range(dtypes.int, TM, first_range+4)
+      index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
+      A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
+
+      # do the GEMM math
+      iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5)
+      yt = UOp.range(dtypes.int, TM, first_range+6)
+      iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7)
+      xt = UOp.range(dtypes.int, TN, first_range+8)
+      x = iterWaveN * TN + xt
+      y = iterWaveM * TM + yt
+      c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
+      # sketchy, this should end the kId_range but it doesn't
+      sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
+                              iterWaveM, iterWaveN, yt, xt, k)
+      return sink
+
+    # TODO: kId_range should endrange after a barrier
+    sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
+
+    # load from registers into locals
+    i = UOp.range(dtypes.int, nbReadsB, 14)
+    index_x = BN * blockIdx_x + rBIdx
+    index_y = rBIdy + i * strideReadB + kId + BK
+    Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
+
+    i = UOp.range(dtypes.int, nbReadsA, 15)
+    index_x = rAIdx + kId + BK
+    index_y = BM * blockIdx_y + rAIdy + i * strideReadA
+    As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
+
+    # final iteration without the copy
+    sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
+  else:
+    kId_range = UOp.range(dtypes.int, N//BK, 0)
+    kId = kId_range*BK
+
+    # load from globals into locals
+    i = UOp.range(dtypes.int, nbReadsB, 1)
+    index_x = BN * blockIdx_x + rBIdx
+    index_y = rBIdy + i * strideReadB + kId
+    Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
+
+    i = UOp.range(dtypes.int, nbReadsA, 2)
+    index_x = rAIdx + kId
+    index_y = BM * blockIdx_y + rAIdy + i * strideReadA
+    As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
+
+    barrier = UOp.barrier(As_store, Bs_store)
+
+    k = UOp.range(dtypes.int, BK, 3)
+
+    # load from locals into registers
+    iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
+    i = UOp.range(dtypes.int, TN, 5)
+    index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
+    B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
+
+    iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
+    i = UOp.range(dtypes.int, TM, 7)
+    index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
+    A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
+
+    # do the GEMM math
+    iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
+    yt = UOp.range(dtypes.int, TM, 9)
+    iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
+    xt = UOp.range(dtypes.int, TN, 12)
+    x = iterWaveN * TN + xt
+    y = iterWaveM * TM + yt
+    c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
+    sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
+                            iterWaveM, iterWaveN, yt, xt, k, kId_range)
 
   # store c_regs into c
-  iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 12)
-  iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 13)
-  yt = UOp.range(dtypes.int, TM, 14)
-  xt = UOp.range(dtypes.int, TN, 15)
+  iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000)
+  yt = UOp.range(dtypes.int, TM, 1001)
+  iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002)
+  xt = UOp.range(dtypes.int, TN, 1003)
   xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
   yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
   indexC = N * (yOut + yt) + xOut + xt
diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py
index 9aade2d260..067e12c1f0 100644
--- a/tinygrad/codegen/linearize.py
+++ b/tinygrad/codegen/linearize.py
@@ -97,7 +97,7 @@ class BlockContext:
 
 # ***** make blocks *****
 
-DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
+DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST}
 
 def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
   ends_to_add = [z for z in new_ctx if z not in current_ctx]
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 9a3a29968e..1a8ca202b0 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -238,7 +238,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
   def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
-  def barrier(self): return UOp(Ops.BARRIER, src=(self,))
+  def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
   def alu(self, op, *src:UOp, **kwargs):
     out_dtype = (self, *src)[-1].dtype
     if op in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool