faster tk matmul (#13006)

wozeparrot
2025-10-30 19:09:27 -07:00
committed by GitHub
parent 512513c403
commit 78f7650eec
3 changed files with 131 additions and 8 deletions

View File

@@ -5,11 +5,11 @@ using namespace kittens;
 constexpr int g_N = 8192;
 constexpr int BLOCK_SIZE = 32;
 #define NUM_WORKERS (1)
 #define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)
 using sub_tile = st_bf<BLOCK_SIZE,BLOCK_SIZE>;
-using tile_gl = gl<bf16, 1, 1, g_N, g_N, sub_tile>;
+using tile_gl = gl<bf16, 1, 1, g_N, g_N>;
 __launch_bounds__(NUM_WORKERS*WARP_THREADS, 1)
 __global__ void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
   tile_gl g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
   tile_gl g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};

View File

@@ -1,10 +1,14 @@
 import pathlib
 from tinygrad import Device, Tensor
-from tinygrad.helpers import Context
+from tinygrad.helpers import Context, getenv
 from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler
 if __name__ == "__main__":
-  code = (pathlib.Path(__file__).parent / "matmul.cu").read_text()
+  if getenv("MATMUL2"):
+    code = (pathlib.Path(__file__).parent / "matmul2.cu").read_text()
+  else:
+    code = (pathlib.Path(__file__).parent / "matmul.cu").read_text()
   device = Device["CUDA"]
   kitten_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr"]
   lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code)
@@ -13,7 +17,10 @@ if __name__ == "__main__":
   print(pretty_ptx(lib.decode()))
   prg = device.runtime(kernel_name, lib)
-  prg.smem = 10000
+  if getenv("MATMUL2"):
+    prg.smem = 16384 * 2
+  else:
+    prg.smem = 10000
   N = 8192
   a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
@@ -21,14 +28,25 @@ if __name__ == "__main__":
   c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16")
   Tensor.realize(a, b, c)
-  BLOCK_SIZE = 32
-  gsz = (N // BLOCK_SIZE, N // BLOCK_SIZE, 1)
+  WARP_THREADS = 32
+  if getenv("MATMUL2"):
+    SUPER_N = 2
+    SUPER_M = 2
+    NUM_WORKERS = SUPER_N * SUPER_M
+    BLOCK_SIZE = 32
+    gsz = (N // (BLOCK_SIZE * SUPER_N), N // (BLOCK_SIZE * SUPER_M), 1)
+  else:
+    NUM_WORKERS = 1
+    BLOCK_SIZE = 32
+    gsz = (N // (BLOCK_SIZE), N // (BLOCK_SIZE), 1)
   for _ in range(5):
     et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
-             global_size=gsz, local_size=(32,1,1), wait=True)
+             global_size=gsz, local_size=(NUM_WORKERS*WARP_THREADS,1,1), wait=True)
     print(f"{N*N*N*2/(et*1e9):2f} GFLOPS")
   # print(c.tolist())
   for _ in range(5):
     with Context(DEBUG=2):
       ref = (a@b).realize()
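Note on the shared-memory request: with MATMUL2 set, the driver asks for 16384 * 2 bytes of dynamic shared memory. A minimal back-of-envelope check of that number, assuming each st_bf<32,32> tile declared in matmul2.cu below occupies exactly 32*32 bf16 values (2 bytes each) with no padding:

# hypothetical sketch, not part of the commit: raw footprint of the
# double-buffered shared tiles declared in matmul2.cu
SUPER_M, SUPER_N, PIPE_STAGES = 2, 2, 2
TILE_BYTES = 32 * 32 * 2                      # one 32x32 bf16 tile
a_bytes = SUPER_M * PIPE_STAGES * TILE_BYTES  # As[SUPER_M][PIPE_STAGES]
b_bytes = SUPER_N * PIPE_STAGES * TILE_BYTES  # Bs[SUPER_N][PIPE_STAGES]
print(a_bytes + b_bytes)                      # 16384

The tile data alone comes to 16 KiB, half of what the driver requests; the remaining headroom presumably covers alignment in the ThunderKittens shared allocator, but that is an assumption rather than something the diff states.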

View File

@@ -0,0 +1,105 @@
#include "kittens.cuh"
using namespace kittens;
constexpr int g_N = 8192;
constexpr int SUPER_N = 2;
constexpr int SUPER_M = 2;
constexpr int NUM_WORKERS = SUPER_N * SUPER_M;
constexpr int LOAD_TASKS = SUPER_N + SUPER_M;
constexpr int WORKER_M = 32;
constexpr int WORKER_N = 32;
constexpr int BLOCK_K = 32;
constexpr int BLOCK_M = WORKER_M * SUPER_M;
constexpr int BLOCK_N = WORKER_N * SUPER_N;
constexpr int PIPE_STAGES = 2;
using reg_tile_A = rt_bf<WORKER_M, BLOCK_K>;
using reg_tile_B_col = rt_bf<BLOCK_K, WORKER_N, ducks::rt_layout::col>;
using reg_tile_C = rt_fl<WORKER_M, WORKER_N>;
using shared_tile_A = st_bf<WORKER_M, BLOCK_K>;
using shared_tile_B = st_bf<BLOCK_K, WORKER_N>;
using shared_tile_C = st_bf<WORKER_M, WORKER_N>;
using gl_tile_A = gl<bf16, 1, 1, g_N, g_N, shared_tile_A>;
using gl_tile_B = gl<bf16, 1, 1, g_N, g_N, shared_tile_B>;
using gl_tile_C = gl<bf16, 1, 1, g_N, g_N, shared_tile_C>;
__launch_bounds__(NUM_WORKERS *WARP_THREADS, 1) __global__
void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
gl_tile_C g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
gl_tile_A g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};
gl_tile_B g_B{b_ptr, nullptr, nullptr, nullptr, nullptr};
extern __shared__ alignment_dummy __shm[];
shared_allocator al((int *)&__shm[0]);
shared_tile_A(&As)[SUPER_M][PIPE_STAGES] =
al.allocate<shared_tile_A, SUPER_M, PIPE_STAGES>();
shared_tile_B(&Bs)[SUPER_N][PIPE_STAGES] =
al.allocate<shared_tile_B, SUPER_N, PIPE_STAGES>();
reg_tile_A A_reg;
reg_tile_B_col B_reg_col;
reg_tile_C C_accum;
int warpid = kittens::warpid();
int warp_m = warpid % SUPER_M;
int warp_n = warpid / SUPER_M;
int load_group_id = warpgroup::groupid();
int block_row = blockIdx.y * SUPER_M;
int block_col = blockIdx.x * SUPER_N;
warp::zero(C_accum);
int num_tiles = (g_N + BLOCK_K - 1) / BLOCK_K;
for (int load_tile = 0; load_tile < (PIPE_STAGES - 1); load_tile++) {
if (load_tile < num_tiles) {
int load_smem_idx = load_tile % PIPE_STAGES;
for (int task_id = warpid; task_id < LOAD_TASKS; task_id += NUM_WORKERS) {
if (task_id < SUPER_M) {
warp::load_async(As[task_id][load_smem_idx], g_A, {0, 0, block_row + task_id, load_tile});
} else {
int n_index = task_id - SUPER_M;
warp::load_async(Bs[n_index][load_smem_idx], g_B, {0, 0, load_tile, block_col + n_index});
}
}
}
}
for (int tile = 0; tile < num_tiles; tile++) {
int compute_smem_idx = tile % PIPE_STAGES;
int load_tile = tile + PIPE_STAGES - 1;
int load_smem_idx = load_tile % PIPE_STAGES;
if (load_tile < num_tiles) {
for (int task_id = warpid; task_id < LOAD_TASKS; task_id += NUM_WORKERS) {
if (task_id < SUPER_M) {
warp::load_async(As[task_id][load_smem_idx], g_A,
{0, 0, block_row + task_id, load_tile});
} else {
int n_index = task_id - SUPER_M;
warp::load_async(Bs[n_index][load_smem_idx], g_B,
{0, 0, load_tile, block_col + n_index});
}
}
load_async_wait<1>();
} else
load_async_wait();
__syncthreads();
warp::load(A_reg, As[warp_m][compute_smem_idx]);
warp::load(B_reg_col, Bs[warp_n][compute_smem_idx]);
warp::mma_AB(C_accum, A_reg, B_reg_col, C_accum);
__syncthreads();
}
warp::store(g_C, C_accum, {0, 0, block_row + warp_m, block_col + warp_n});
}