diff --git a/extra/thunder/cuda/matmul.cu b/extra/thunder/cuda/matmul.cu index 29cab02292..3f0ea766d1 100644 --- a/extra/thunder/cuda/matmul.cu +++ b/extra/thunder/cuda/matmul.cu @@ -5,11 +5,11 @@ using namespace kittens; constexpr int g_N = 8192; constexpr int BLOCK_SIZE = 32; #define NUM_WORKERS (1) -#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS) using sub_tile = st_bf; -using tile_gl = gl; +using tile_gl = gl; +__launch_bounds__(NUM_WORKERS*WARP_THREADS, 1) __global__ void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) { tile_gl g_C{c_ptr, nullptr, nullptr, nullptr, nullptr}; tile_gl g_A{a_ptr, nullptr, nullptr, nullptr, nullptr}; diff --git a/extra/thunder/cuda/matmul.py b/extra/thunder/cuda/matmul.py index ea0454edcb..ace4cd9adf 100644 --- a/extra/thunder/cuda/matmul.py +++ b/extra/thunder/cuda/matmul.py @@ -1,10 +1,14 @@ import pathlib from tinygrad import Device, Tensor -from tinygrad.helpers import Context +from tinygrad.helpers import Context, getenv from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler if __name__ == "__main__": - code = (pathlib.Path(__file__).parent / "matmul.cu").read_text() + if getenv("MATMUL2"): + code = (pathlib.Path(__file__).parent / "matmul2.cu").read_text() + else: + code = (pathlib.Path(__file__).parent / "matmul.cu").read_text() + device = Device["CUDA"] kitten_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr"] lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code) @@ -13,7 +17,10 @@ if __name__ == "__main__": print(pretty_ptx(lib.decode())) prg = device.runtime(kernel_name, lib) - prg.smem = 10000 + if getenv("MATMUL2"): + prg.smem = 16384 * 2 + else: + prg.smem = 10000 N = 8192 a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16") @@ -21,14 +28,25 @@ if __name__ == "__main__": c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16") Tensor.realize(a, b, c) - BLOCK_SIZE = 32 + WARP_THREADS = 32 + if getenv("MATMUL2"): + SUPER_N = 2 + SUPER_M = 2 + NUM_WORKERS = SUPER_N * SUPER_M + BLOCK_SIZE = 32 + gsz = (N // (BLOCK_SIZE * SUPER_N), N // (BLOCK_SIZE * SUPER_M), 1) + else: + NUM_WORKERS = 1 + BLOCK_SIZE = 32 + gsz = (N // (BLOCK_SIZE), N // (BLOCK_SIZE), 1) - gsz = (N // BLOCK_SIZE, N // BLOCK_SIZE, 1) for _ in range(5): et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf, - global_size=gsz, local_size=(32,1,1), wait=True) + global_size=gsz, local_size=(NUM_WORKERS*WARP_THREADS,1,1), wait=True) print(f"{N*N*N*2/(et*1e9):2f} GFLOPS") + # print(c.tolist()) + for _ in range(5): with Context(DEBUG=2): ref = (a@b).realize() diff --git a/extra/thunder/cuda/matmul2.cu b/extra/thunder/cuda/matmul2.cu new file mode 100644 index 0000000000..6c31cf5766 --- /dev/null +++ b/extra/thunder/cuda/matmul2.cu @@ -0,0 +1,105 @@ +#include "kittens.cuh" +using namespace kittens; + +constexpr int g_N = 8192; + +constexpr int SUPER_N = 2; +constexpr int SUPER_M = 2; +constexpr int NUM_WORKERS = SUPER_N * SUPER_M; +constexpr int LOAD_TASKS = SUPER_N + SUPER_M; + +constexpr int WORKER_M = 32; +constexpr int WORKER_N = 32; + +constexpr int BLOCK_K = 32; +constexpr int BLOCK_M = WORKER_M * SUPER_M; +constexpr int BLOCK_N = WORKER_N * SUPER_N; + +constexpr int PIPE_STAGES = 2; + +using reg_tile_A = rt_bf; +using reg_tile_B_col = rt_bf; +using reg_tile_C = rt_fl; + +using shared_tile_A = st_bf; +using shared_tile_B = st_bf; +using shared_tile_C = st_bf; + +using gl_tile_A = gl; +using gl_tile_B = gl; +using gl_tile_C = gl; + +__launch_bounds__(NUM_WORKERS *WARP_THREADS, 1) __global__ + void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) { + gl_tile_C g_C{c_ptr, nullptr, nullptr, nullptr, nullptr}; + gl_tile_A g_A{a_ptr, nullptr, nullptr, nullptr, nullptr}; + gl_tile_B g_B{b_ptr, nullptr, nullptr, nullptr, nullptr}; + + extern __shared__ alignment_dummy __shm[]; + shared_allocator al((int *)&__shm[0]); + + shared_tile_A(&As)[SUPER_M][PIPE_STAGES] = + al.allocate(); + shared_tile_B(&Bs)[SUPER_N][PIPE_STAGES] = + al.allocate(); + + reg_tile_A A_reg; + reg_tile_B_col B_reg_col; + reg_tile_C C_accum; + + int warpid = kittens::warpid(); + int warp_m = warpid % SUPER_M; + int warp_n = warpid / SUPER_M; + + int load_group_id = warpgroup::groupid(); + + int block_row = blockIdx.y * SUPER_M; + int block_col = blockIdx.x * SUPER_N; + + warp::zero(C_accum); + int num_tiles = (g_N + BLOCK_K - 1) / BLOCK_K; + + for (int load_tile = 0; load_tile < (PIPE_STAGES - 1); load_tile++) { + if (load_tile < num_tiles) { + int load_smem_idx = load_tile % PIPE_STAGES; + for (int task_id = warpid; task_id < LOAD_TASKS; task_id += NUM_WORKERS) { + if (task_id < SUPER_M) { + warp::load_async(As[task_id][load_smem_idx], g_A, {0, 0, block_row + task_id, load_tile}); + } else { + int n_index = task_id - SUPER_M; + warp::load_async(Bs[n_index][load_smem_idx], g_B, {0, 0, load_tile, block_col + n_index}); + } + } + } + } + + for (int tile = 0; tile < num_tiles; tile++) { + int compute_smem_idx = tile % PIPE_STAGES; + + int load_tile = tile + PIPE_STAGES - 1; + int load_smem_idx = load_tile % PIPE_STAGES; + + if (load_tile < num_tiles) { + for (int task_id = warpid; task_id < LOAD_TASKS; task_id += NUM_WORKERS) { + if (task_id < SUPER_M) { + warp::load_async(As[task_id][load_smem_idx], g_A, + {0, 0, block_row + task_id, load_tile}); + } else { + int n_index = task_id - SUPER_M; + warp::load_async(Bs[n_index][load_smem_idx], g_B, + {0, 0, load_tile, block_col + n_index}); + } + } + load_async_wait<1>(); + } else + load_async_wait(); + __syncthreads(); + + warp::load(A_reg, As[warp_m][compute_smem_idx]); + warp::load(B_reg_col, Bs[warp_n][compute_smem_idx]); + + warp::mma_AB(C_accum, A_reg, B_reg_col, C_accum); + __syncthreads(); + } + warp::store(g_C, C_accum, {0, 0, block_row + warp_m, block_col + warp_n}); +}