working kitten matmul (#12935)

wozeparrot
2025-10-26 23:40:49 -07:00
committed by GitHub
parent 189582db5e
commit 6b54378eba
4 changed files with 88 additions and 6 deletions

View File

@@ -65,8 +65,8 @@ template&lt;typename _T, int _axis=-9999, bool _swizzle_flag=true&gt; struct descriptor
 namespace detail {
 template<typename... Args>
 struct descriptor_dict {
-    __host__ descriptor_dict() {}
-    template<typename T> __host__ descriptor_dict(T _, int b, int d, int r, int c) {}
+    __host__ __device__ descriptor_dict() {}
+    template<typename T> __host__ __device__ descriptor_dict(T _, int b, int d, int r, int c) {}
     __host__ __device__ descriptor_dict(const descriptor_dict &other) {}
 #ifdef KITTENS_HOPPER
     template<typename T, int U> __device__ const CUtensorMap* get() const {
@@ -85,8 +85,8 @@ struct descriptor_dict<_T, Args...> {
     using DESC = kittens::tma::descriptor<_T>; // copy or initialize with a default value
     CUtensorMap tma_desc;
     descriptor_dict<Args...> other_descs;
-    __host__ descriptor_dict() {}
-    __host__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) {
+    __host__ __device__ descriptor_dict() {}
+    __host__ __device__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) {
         kittens::detail::tma::create_tensor_map<typename DESC::T, DESC::axis, DESC::swizzle_flag>(&tma_desc, data, b, d, r, c);
     }
     __host__ __device__ inline descriptor_dict(const descriptor_dict &other) :
@@ -135,7 +135,7 @@ struct gl {
     detail::descriptor_dict<TMA_Types...> tma_descs;
-    __host__ inline gl(T *_data,
+    __host__ __device__ inline gl(T *_data,
                        ducks::gl::make_arg_t<b> _batch,
                        ducks::gl::make_arg_t<d> _depth,
                        ducks::gl::make_arg_t<r> _rows,
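A minimal sketch (hypothetical kernel and pointer names) of the device-side usage these __host__ __device__ qualifiers enable; the new matmul kernel below constructs its global layouts the same way:

// sketch only: a gl layout built inside a kernel from a raw device pointer,
// which previously did not compile because the gl/descriptor_dict constructors were __host__-only
using example_gl = gl<bf16, 1, 1, 8192, 8192, st_bf<32, 32>>;
__global__ void device_side_gl(bf16 *ptr) {
  example_gl layout{ptr, nullptr, nullptr, nullptr, nullptr};
  // ...loads and stores go through `layout`, as in matmul.cu below...
}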

View File

@@ -425,4 +425,4 @@ __host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typename
 } // namespace tma
 } // namespace detail
-} // namespace kittens
+} // namespace kittens

View File

@@ -0,0 +1,45 @@
// https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/matmul/educational/level_04.cu
#include "kittens.cuh"
using namespace kittens;
constexpr int g_N = 8192;
constexpr int BLOCK_SIZE = 32;
#define NUM_WORKERS (1)
#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)
using sub_tile = st_bf<BLOCK_SIZE,BLOCK_SIZE>;
using tile_gl = gl<bf16, 1, 1, g_N, g_N, sub_tile>;
__global__ void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
  // global-memory layouts built on-device from the raw pointers
  // (this is what the __host__ __device__ gl constructor above enables)
  tile_gl g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
  tile_gl g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};
  tile_gl g_B{b_ptr, nullptr, nullptr, nullptr, nullptr};
  // dynamic shared memory: one 32x32 bf16 tile of A and one of B
  extern __shared__ alignment_dummy __shm[];
  shared_allocator al((int*)&__shm[0]);
  st_bf<BLOCK_SIZE,BLOCK_SIZE> &As = al.allocate<st_bf<BLOCK_SIZE,BLOCK_SIZE>>();
  st_bf<BLOCK_SIZE,BLOCK_SIZE> &Bs = al.allocate<st_bf<BLOCK_SIZE,BLOCK_SIZE>>();
  rt_bf<BLOCK_SIZE,BLOCK_SIZE> A_reg;
  rt_bf<BLOCK_SIZE,BLOCK_SIZE> B_reg;
  rt_bf<BLOCK_SIZE,BLOCK_SIZE, ducks::rt_layout::col> B_reg_col;
  rt_fl<BLOCK_SIZE,BLOCK_SIZE> C_accum;  // fp32 accumulator for this block's output tile
  int col = blockIdx.x;
  int row = blockIdx.y;
  warp::zero(C_accum);
  int num_tiles = (g_N + BLOCK_SIZE - 1) / BLOCK_SIZE;
  // walk the K dimension one 32x32 tile at a time
  for (int tile = 0; tile < num_tiles; ++tile) {
    warp::load(As, g_A, {0, 0, row, tile});
    warp::load(Bs, g_B, {0, 0, tile, col});
    __syncthreads();
    warp::load(A_reg, As);
    warp::load(B_reg, Bs);
    warp::swap_layout(B_reg_col, B_reg);  // mma_AB expects B as a column-layout register tile
    __syncthreads();
    warp::mma_AB(C_accum, A_reg, B_reg_col, C_accum);
    __syncthreads();
  }
  warp::store(g_C, C_accum, {0, 0, row, col});
}
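The tinygrad script below is what actually launches this kernel; for reference, a rough plain-CUDA equivalent of that launch (sketch only, d_A/d_B/d_C are hypothetical device pointers):

// one warp per block, one block per 32x32 tile of C, dynamic smem for the As/Bs tiles
dim3 grid(g_N / BLOCK_SIZE, g_N / BLOCK_SIZE, 1);   // (256, 256, 1) for g_N = 8192
dim3 block(NUM_THREADS, 1, 1);                      // 32 threads, i.e. one warp
size_t smem_bytes = 10000;                          // matches prg.smem in the script; >= 2*sizeof(st_bf<32,32>)
kernel<<<grid, block, smem_bytes>>>(d_C, d_A, d_B);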

View File

@@ -0,0 +1,37 @@
import pathlib
from tinygrad import Device, Tensor
from tinygrad.helpers import Context
from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler
if __name__ == "__main__":
  code = (pathlib.Path(__file__).parent / "matmul.cu").read_text()
  device = Device["CUDA"]
  # compile the ThunderKittens kernel directly with NVCC, pointing it at the bundled headers
  kitten_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr", "-DKITTENS_HOPPER"]
  lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code)
  # recover the mangled kernel name from the generated PTX
  kernel_name = lib.decode().split(".globl\t")[1].split("\n")[0]
  print("kernel name", kernel_name)
  print(pretty_ptx(lib.decode()))
  prg = device.runtime(kernel_name, lib)
  prg.smem = 10000  # dynamic shared memory: enough for the two 32x32 bf16 tiles
  N = 8192
  a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
  b = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
  c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16")
  Tensor.realize(a, b, c)
  BLOCK_SIZE = 32
  gsz = (N // BLOCK_SIZE, N // BLOCK_SIZE, 1)  # one block per 32x32 output tile
  # benchmark the ThunderKittens kernel
  for _ in range(5):
    et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
             global_size=gsz, local_size=(32,1,1), wait=True)
    print(f"{N*N*N*2/(et*1e9):.2f} GFLOPS")
  # benchmark tinygrad's own matmul for comparison
  for _ in range(5):
    with Context(DEBUG=2):
      ref = (a@b).realize()
  # check the kernel's output against the reference
  ref, c = ref.float(), c.float()
  print((ref-c).mean().item(), (ref-c).max().item())