mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
working kitten matmul (#12935)
This commit is contained in:
@@ -65,8 +65,8 @@ template<typename _T, int _axis=-9999, bool _swizzle_flag=true> struct descripto
|
||||
namespace detail {
|
||||
template<typename... Args>
|
||||
struct descriptor_dict {
|
||||
__host__ descriptor_dict() {}
|
||||
template<typename T> __host__ descriptor_dict(T _, int b, int d, int r, int c) {}
|
||||
__host__ __device__ descriptor_dict() {}
|
||||
template<typename T> __host__ __device__ descriptor_dict(T _, int b, int d, int r, int c) {}
|
||||
__host__ __device__ descriptor_dict(const descriptor_dict &other) {}
|
||||
#ifdef KITTENS_HOPPER
|
||||
template<typename T, int U> __device__ const CUtensorMap* get() const {
|
||||
@@ -85,8 +85,8 @@ struct descriptor_dict<_T, Args...> {
|
||||
using DESC = kittens::tma::descriptor<_T>; // copy or initialize with a default value
|
||||
CUtensorMap tma_desc;
|
||||
descriptor_dict<Args...> other_descs;
|
||||
__host__ descriptor_dict() {}
|
||||
__host__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) {
|
||||
__host__ __device__ descriptor_dict() {}
|
||||
__host__ __device__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) {
|
||||
kittens::detail::tma::create_tensor_map<typename DESC::T, DESC::axis, DESC::swizzle_flag>(&tma_desc, data, b, d, r, c);
|
||||
}
|
||||
__host__ __device__ inline descriptor_dict(const descriptor_dict &other) :
|
||||
@@ -135,7 +135,7 @@ struct gl {
|
||||
|
||||
detail::descriptor_dict<TMA_Types...> tma_descs;
|
||||
|
||||
__host__ inline gl(T *_data,
|
||||
__host__ __device__ inline gl(T *_data,
|
||||
ducks::gl::make_arg_t<b> _batch,
|
||||
ducks::gl::make_arg_t<d> _depth,
|
||||
ducks::gl::make_arg_t<r> _rows,
|
||||
|
||||
@@ -425,4 +425,4 @@ __host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typenam
|
||||
|
||||
} // namespace tma
|
||||
} // namespace detail
|
||||
} // namespace kittens
|
||||
} // namespace kittens
|
||||
|
||||
45
extra/thunder/cuda/matmul.cu
Normal file
45
extra/thunder/cuda/matmul.cu
Normal file
@@ -0,0 +1,45 @@
|
||||
// https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/matmul/educational/level_04.cu
|
||||
#include "kittens.cuh"
|
||||
using namespace kittens;
|
||||
|
||||
// Side length of the square matrices; used for both the row and column
// dimension of the global layout below.
constexpr int g_N = 8192;

// Edge length of the square tile used for the shared and register tiles.
constexpr int BLOCK_SIZE = 32;

// Warps per thread block, and the resulting threads per block.
#define NUM_WORKERS (1)
#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)

// bf16 shared-memory tile of BLOCK_SIZE x BLOCK_SIZE elements.
using sub_tile = st_bf<BLOCK_SIZE, BLOCK_SIZE>;

// Global layout for a g_N x g_N bf16 matrix with batch = depth = 1;
// sub_tile is passed as the trailing template argument of gl.
using tile_gl = gl<bf16, 1, 1, g_N, g_N, sub_tile>;
|
||||
|
||||
// Tiled bf16 matmul: C = A * B for g_N x g_N matrices.
//
// Launch layout: blockIdx.x selects the tile column and blockIdx.y the tile
// row of the BLOCK_SIZE x BLOCK_SIZE output tile this block produces.
// Shared memory: dynamic (extern __shared__), holding two sub_tile buffers.
//
// c_ptr: output matrix; a_ptr / b_ptr: input matrices. All three are wrapped
// in tile_gl, i.e. treated as g_N x g_N bf16 arrays.
__global__ void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
  // Every dimension of tile_gl is fixed at compile time, so the four runtime
  // dimension arguments are nullptr placeholders.
  tile_gl g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
  tile_gl g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};
  tile_gl g_B{b_ptr, nullptr, nullptr, nullptr, nullptr};

  // Carve the A and B staging tiles out of dynamic shared memory.
  extern __shared__ alignment_dummy __shm[];
  shared_allocator alloc(reinterpret_cast<int*>(&__shm[0]));
  sub_tile &a_smem = alloc.allocate<sub_tile>();
  sub_tile &b_smem = alloc.allocate<sub_tile>();

  // Register tiles: A, B as loaded, B converted to column layout for the MMA,
  // and the rt_fl accumulator.
  rt_bf<BLOCK_SIZE, BLOCK_SIZE> a_frag;
  rt_bf<BLOCK_SIZE, BLOCK_SIZE> b_frag_row;
  rt_bf<BLOCK_SIZE, BLOCK_SIZE, ducks::rt_layout::col> b_frag_col;
  rt_fl<BLOCK_SIZE, BLOCK_SIZE> acc;

  const int tile_col = blockIdx.x;
  const int tile_row = blockIdx.y;

  warp::zero(acc);

  // Number of tiles along the reduction dimension (ceil-div).
  constexpr int k_tiles = (g_N + BLOCK_SIZE - 1) / BLOCK_SIZE;
  for (int k = 0; k < k_tiles; ++k) {
    // Stage the k-th tile of A's tile row and of B's tile column in shared memory.
    warp::load(a_smem, g_A, {0, 0, tile_row, k});
    warp::load(b_smem, g_B, {0, 0, k, tile_col});
    __syncthreads();  // shared tiles fully written before they are read

    // Shared -> registers, then convert B to the column layout passed to mma_AB.
    warp::load(a_frag, a_smem);
    warp::load(b_frag_row, b_smem);
    warp::swap_layout(b_frag_col, b_frag_row);
    __syncthreads();  // shared tiles fully read before the next iteration overwrites them

    // acc = a_frag * b_frag_col + acc
    warp::mma_AB(acc, a_frag, b_frag_col, acc);
    __syncthreads();
  }

  // Write the finished output tile back to C.
  warp::store(g_C, acc, {0, 0, tile_row, tile_col});
}
|
||||
37
extra/thunder/cuda/matmul.py
Normal file
37
extra/thunder/cuda/matmul.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import pathlib
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler
|
||||
|
||||
if __name__ == "__main__":
  # Compile the ThunderKittens kernel that sits next to this script, run it on
  # the CUDA device, report throughput, and compare against tinygrad's matmul.
  here = pathlib.Path(__file__).parent
  code = (here / "matmul.cu").read_text()
  device = Device["CUDA"]
  kitten_args = [f"-I{(here / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr", "-DKITTENS_HOPPER"]
  lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code)
  # The kernel is not declared extern "C", so take its name from the first
  # .globl directive in the compiled output.
  kernel_name = lib.decode().split(".globl\t")[1].split("\n")[0]
  print("kernel name", kernel_name)
  print(pretty_ptx(lib.decode()))

  prg = device.runtime(kernel_name, lib)
  # Dynamic shared memory size requested for the launch, in bytes.
  prg.smem = 10000

  # N and BLOCK_SIZE must match g_N and BLOCK_SIZE in matmul.cu.
  N = 8192
  a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
  b = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
  c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16")
  Tensor.realize(a, b, c)

  BLOCK_SIZE = 32

  # One thread block per output tile, one warp (32 threads) per block.
  gsz = (N // BLOCK_SIZE, N // BLOCK_SIZE, 1)
  for _ in range(5):
    et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
             global_size=gsz, local_size=(32,1,1), wait=True)
    # 2*N^3 floating point operations per matmul; ".2f" prints two decimals
    # (the previous "2f" spec only set a minimum field width of 2).
    print(f"{N*N*N*2/(et*1e9):.2f} GFLOPS")

  # Reference result from tinygrad, with DEBUG=2 set while it is realized.
  for _ in range(5):
    with Context(DEBUG=2):
      ref = (a@b).realize()

  ref, c = ref.float(), c.float()
  # Report absolute error: a signed difference lets positive and negative
  # errors cancel in the mean, and its max never shows large negative errors.
  err = (ref-c).abs()
  print(err.mean().item(), err.max().item())
|
||||
Reference in New Issue
Block a user