mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
faster tk matmul (#13006)
This commit is contained in:
@@ -5,11 +5,11 @@ using namespace kittens;
|
||||
constexpr int g_N = 8192;
|
||||
constexpr int BLOCK_SIZE = 32;
|
||||
#define NUM_WORKERS (1)
|
||||
#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)
|
||||
|
||||
using sub_tile = st_bf<BLOCK_SIZE,BLOCK_SIZE>;
|
||||
using tile_gl = gl<bf16, 1, 1, g_N, g_N, sub_tile>;
|
||||
using tile_gl = gl<bf16, 1, 1, g_N, g_N>;
|
||||
|
||||
__launch_bounds__(NUM_WORKERS*WARP_THREADS, 1)
|
||||
__global__ void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
|
||||
tile_gl g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
tile_gl g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import pathlib
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.helpers import Context, getenv
|
||||
from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler
|
||||
|
||||
if __name__ == "__main__":
|
||||
code = (pathlib.Path(__file__).parent / "matmul.cu").read_text()
|
||||
if getenv("MATMUL2"):
|
||||
code = (pathlib.Path(__file__).parent / "matmul2.cu").read_text()
|
||||
else:
|
||||
code = (pathlib.Path(__file__).parent / "matmul.cu").read_text()
|
||||
|
||||
device = Device["CUDA"]
|
||||
kitten_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr"]
|
||||
lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code)
|
||||
@@ -13,7 +17,10 @@ if __name__ == "__main__":
|
||||
print(pretty_ptx(lib.decode()))
|
||||
|
||||
prg = device.runtime(kernel_name, lib)
|
||||
prg.smem = 10000
|
||||
if getenv("MATMUL2"):
|
||||
prg.smem = 16384 * 2
|
||||
else:
|
||||
prg.smem = 10000
|
||||
|
||||
N = 8192
|
||||
a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
|
||||
@@ -21,14 +28,25 @@ if __name__ == "__main__":
|
||||
c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16")
|
||||
Tensor.realize(a, b, c)
|
||||
|
||||
BLOCK_SIZE = 32
|
||||
WARP_THREADS = 32
|
||||
if getenv("MATMUL2"):
|
||||
SUPER_N = 2
|
||||
SUPER_M = 2
|
||||
NUM_WORKERS = SUPER_N * SUPER_M
|
||||
BLOCK_SIZE = 32
|
||||
gsz = (N // (BLOCK_SIZE * SUPER_N), N // (BLOCK_SIZE * SUPER_M), 1)
|
||||
else:
|
||||
NUM_WORKERS = 1
|
||||
BLOCK_SIZE = 32
|
||||
gsz = (N // (BLOCK_SIZE), N // (BLOCK_SIZE), 1)
|
||||
|
||||
gsz = (N // BLOCK_SIZE, N // BLOCK_SIZE, 1)
|
||||
for _ in range(5):
|
||||
et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
|
||||
global_size=gsz, local_size=(32,1,1), wait=True)
|
||||
global_size=gsz, local_size=(NUM_WORKERS*WARP_THREADS,1,1), wait=True)
|
||||
print(f"{N*N*N*2/(et*1e9):2f} GFLOPS")
|
||||
|
||||
# print(c.tolist())
|
||||
|
||||
for _ in range(5):
|
||||
with Context(DEBUG=2):
|
||||
ref = (a@b).realize()
|
||||
|
||||
105
extra/thunder/cuda/matmul2.cu
Normal file
105
extra/thunder/cuda/matmul2.cu
Normal file
@@ -0,0 +1,105 @@
|
||||
#include "kittens.cuh"
|
||||
using namespace kittens;
|
||||
|
||||
constexpr int g_N = 8192;
|
||||
|
||||
constexpr int SUPER_N = 2;
|
||||
constexpr int SUPER_M = 2;
|
||||
constexpr int NUM_WORKERS = SUPER_N * SUPER_M;
|
||||
constexpr int LOAD_TASKS = SUPER_N + SUPER_M;
|
||||
|
||||
constexpr int WORKER_M = 32;
|
||||
constexpr int WORKER_N = 32;
|
||||
|
||||
constexpr int BLOCK_K = 32;
|
||||
constexpr int BLOCK_M = WORKER_M * SUPER_M;
|
||||
constexpr int BLOCK_N = WORKER_N * SUPER_N;
|
||||
|
||||
constexpr int PIPE_STAGES = 2;
|
||||
|
||||
using reg_tile_A = rt_bf<WORKER_M, BLOCK_K>;
|
||||
using reg_tile_B_col = rt_bf<BLOCK_K, WORKER_N, ducks::rt_layout::col>;
|
||||
using reg_tile_C = rt_fl<WORKER_M, WORKER_N>;
|
||||
|
||||
using shared_tile_A = st_bf<WORKER_M, BLOCK_K>;
|
||||
using shared_tile_B = st_bf<BLOCK_K, WORKER_N>;
|
||||
using shared_tile_C = st_bf<WORKER_M, WORKER_N>;
|
||||
|
||||
using gl_tile_A = gl<bf16, 1, 1, g_N, g_N, shared_tile_A>;
|
||||
using gl_tile_B = gl<bf16, 1, 1, g_N, g_N, shared_tile_B>;
|
||||
using gl_tile_C = gl<bf16, 1, 1, g_N, g_N, shared_tile_C>;
|
||||
|
||||
__launch_bounds__(NUM_WORKERS *WARP_THREADS, 1) __global__
|
||||
void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
|
||||
gl_tile_C g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
gl_tile_A g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
gl_tile_B g_B{b_ptr, nullptr, nullptr, nullptr, nullptr};
|
||||
|
||||
extern __shared__ alignment_dummy __shm[];
|
||||
shared_allocator al((int *)&__shm[0]);
|
||||
|
||||
shared_tile_A(&As)[SUPER_M][PIPE_STAGES] =
|
||||
al.allocate<shared_tile_A, SUPER_M, PIPE_STAGES>();
|
||||
shared_tile_B(&Bs)[SUPER_N][PIPE_STAGES] =
|
||||
al.allocate<shared_tile_B, SUPER_N, PIPE_STAGES>();
|
||||
|
||||
reg_tile_A A_reg;
|
||||
reg_tile_B_col B_reg_col;
|
||||
reg_tile_C C_accum;
|
||||
|
||||
int warpid = kittens::warpid();
|
||||
int warp_m = warpid % SUPER_M;
|
||||
int warp_n = warpid / SUPER_M;
|
||||
|
||||
int load_group_id = warpgroup::groupid();
|
||||
|
||||
int block_row = blockIdx.y * SUPER_M;
|
||||
int block_col = blockIdx.x * SUPER_N;
|
||||
|
||||
warp::zero(C_accum);
|
||||
int num_tiles = (g_N + BLOCK_K - 1) / BLOCK_K;
|
||||
|
||||
for (int load_tile = 0; load_tile < (PIPE_STAGES - 1); load_tile++) {
|
||||
if (load_tile < num_tiles) {
|
||||
int load_smem_idx = load_tile % PIPE_STAGES;
|
||||
for (int task_id = warpid; task_id < LOAD_TASKS; task_id += NUM_WORKERS) {
|
||||
if (task_id < SUPER_M) {
|
||||
warp::load_async(As[task_id][load_smem_idx], g_A, {0, 0, block_row + task_id, load_tile});
|
||||
} else {
|
||||
int n_index = task_id - SUPER_M;
|
||||
warp::load_async(Bs[n_index][load_smem_idx], g_B, {0, 0, load_tile, block_col + n_index});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int tile = 0; tile < num_tiles; tile++) {
|
||||
int compute_smem_idx = tile % PIPE_STAGES;
|
||||
|
||||
int load_tile = tile + PIPE_STAGES - 1;
|
||||
int load_smem_idx = load_tile % PIPE_STAGES;
|
||||
|
||||
if (load_tile < num_tiles) {
|
||||
for (int task_id = warpid; task_id < LOAD_TASKS; task_id += NUM_WORKERS) {
|
||||
if (task_id < SUPER_M) {
|
||||
warp::load_async(As[task_id][load_smem_idx], g_A,
|
||||
{0, 0, block_row + task_id, load_tile});
|
||||
} else {
|
||||
int n_index = task_id - SUPER_M;
|
||||
warp::load_async(Bs[n_index][load_smem_idx], g_B,
|
||||
{0, 0, load_tile, block_col + n_index});
|
||||
}
|
||||
}
|
||||
load_async_wait<1>();
|
||||
} else
|
||||
load_async_wait();
|
||||
__syncthreads();
|
||||
|
||||
warp::load(A_reg, As[warp_m][compute_smem_idx]);
|
||||
warp::load(B_reg_col, Bs[warp_n][compute_smem_idx]);
|
||||
|
||||
warp::mma_AB(C_accum, A_reg, B_reg_col, C_accum);
|
||||
__syncthreads();
|
||||
}
|
||||
warp::store(g_C, C_accum, {0, 0, block_row + warp_m, block_col + warp_n});
|
||||
}
|
||||
Reference in New Issue
Block a user