extra/gemm/max_matmul: start of custom kernels for GEMM (#6926)

* extra/gemm/max_matmul: start of custom kernels for GEMM

* add an unoptimized FP16/FP16 MMA example

* add slow 3-stage fp16 acc example

* add correct 3-stage pipeline with unswizzled/flat smem input (slow)

* add acc fp16 example with 3 stages and swizzle (no bank conflicts)

* add max version of NV fp16_fp16_fp16

* fix up comments and remove unused code in max variations

* add start of no_xor example

* fix to account for UOps to Ops
Author: Francis Lam
Date: 2025-03-19 00:04:57 -07:00
Committed by: GitHub
Parent: 865f23dd7b
Commit: 1e5d9ad8f7
11 changed files with 4418 additions and 0 deletions
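
For context, a minimal host-side harness for these kernels might look like the sketch below. The grid and block shapes are inferred from the first kernel's comments (one 64x128 output tile of C per block, 128 threads); the launcher name and device pointers are illustrative and not part of this commit. The 3-stage variants further down use a large dynamically sized shared-memory buffer, so they would additionally need the dynamic shared-memory limit raised (e.g. with cudaFuncSetAttribute) and the byte count passed at launch.

#include <cuda_fp16.h>
// kernel defined in the first file of this commit
extern "C" __global__ void wmma_example(half* data0, const half* data1, const half* data2, int N, int K);
// hypothetical launcher: C (MxN) = A (MxK) * B (KxN), all fp16, row-major
void launch_wmma_example(half *d_c, const half *d_a, const half *d_b, int M, int N, int K) {
  dim3 grid(M / 64, N / 128);  // one 64x128 tile of C per block (assumes M, N, K are multiples of the tile sizes)
  dim3 block(128);             // 4 warps
  wmma_example<<<grid, block>>>(d_c, d_a, d_b, N, K);
}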


@@ -0,0 +1,508 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
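// load a 16x16 half tile of A from shared memory with ldmatrix.x4: the warp reads four 8x8 b16 matrices,
// each thread supplies the address of one 8-half row segment and receives 8 halves (the m16n8k16 A fragment)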
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
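// load a 16x16 half tile of B with the transposing variant (ldmatrix ... .trans); the four returned registers
// are split into low and high N=8 halves, i.e. two m16n8k16 B fragments of 4 halves each per thread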
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
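// one mma.sync m16n8k16 with fp16 accumulate: D(16x8) = A(16x16) * B(16x8) + C(16x8);
// per thread, A is 4 registers (8 halves) and B, C, D are 2 registers (4 halves) each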
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
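// each block computes a 64x128 tile of C with 128 threads (4 warps in a 2x2 M x N layout); K is consumed
// 64 elements at a time with double-buffered cp.async prefetch into XOR-swizzled shared memory, and each
// warp issues 2x8 m16n8k16 MMAs per 16-element K step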
extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
int grid_m = blockIdx.x; /* M//64 */
int grid_n = blockIdx.y; /* N//128 */
int threads = threadIdx.x; /* 128 */
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
int wg_threads = threads%32;
int num_k_blocks = K / 64;
// load indexes
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// swizzled smem store offsets - columns of smem are swizzled
// here is a description of the Triton swizzle scheme: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
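// the XOR folds the row index into the 8-column group index, so the eight 16-byte chunks written by a
// warp land in distinct bank groups; a hypothetical one-off check (not part of the kernel) is to print
// the swizzled column group per thread and confirm each group of 8 threads covers all eight groups:
// if (blockIdx.x == 0 && blockIdx.y == 0 && threads < 16)
//   printf("t%d -> A col group %d\n", threads, (int)((store_smem_a_off % 64) / 8));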
// ldmatrix indices
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled ldmatrix
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
size_t load_smem_a_phase = (threads / 16) % 2; // r4
size_t load_smem_b_row = (threads % 16) * 128; // r299
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
// create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B_2 16384 bytes)
__shared__ alignas(16) char smem[49152];
// create accs (16 WMMAs and 4 output elements each) and zero
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements (2)
half8 a_frag_0;
half8 a_frag_1;
// create register for block B elements (8)
half4 b_frag_0;
half4 b_frag_1;
half4 b_frag_2;
half4 b_frag_3;
half4 b_frag_4;
half4 b_frag_5;
half4 b_frag_6;
half4 b_frag_7;
half *smem_a_even = (half *)(smem);
half *smem_a_odd = (half *)(smem + 8192);
half *smem_b_even = (half *)(smem + 16384);
half *smem_b_odd = (half *)(smem + 32768);
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
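// double-buffered prefetch: two K=64 tiles are kept in flight in the even/odd smem buffers; each loop
// iteration consumes one buffer and issues the cp.async for the tile two K-steps ahead into that same buffer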
// start first pre-fetch load A
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start first pre-fetch load B
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
__syncthreads();
// start second pre-fetch load A
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start second pre-fetch load B
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
// wait on needed prefetch value
__pipeline_wait_prior(0); // TODO: wait_prior(1) should suffice and allow faster iteration, but it currently gives incorrect results (possibly because block_k==0 reads the odd buffers filled by the second prefetch)
__syncthreads();
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd;
half *smem_b_curr = (block_k % 2) ? smem_b_even : smem_b_odd;
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
// last 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
// prefetch next iteration if needed
__syncthreads();
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
global_a_off += 64;
global_b_off += 64 * N;
}
__pipeline_commit();
if (block_k < num_k_blocks-1) {
__pipeline_wait_prior(1);
__syncthreads();
}
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// // store registers to smem first, then read back to do float4 writes to global
// float *smem_d = (float *)(smem);
// size_t smem_d_off = (wg_m * 16 * N_PAD) + (wg_n * 16) + ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N_PAD);
// smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_0_0.x;
// smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_0_0.y;
// smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.z;
// smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.w;
// smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_0_1.x;
// smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_0_1.y;
// smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.z;
// smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.w;
// smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_0_2.x;
// smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_0_2.y;
// smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.z;
// smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.w;
// smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_0_3.x;
// smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_0_3.y;
// smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.z;
// smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.w;
// smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_0_4.x;
// smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_0_4.y;
// smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.z;
// smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.w;
// smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_0_5.x;
// smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_0_5.y;
// smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.z;
// smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.w;
// smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_0_6.x;
// smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_0_6.y;
// smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_0_6.z;
// smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_0_6.w;
// smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_0_7.x;
// smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_0_7.y;
// smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_0_7.z;
// smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_0_7.w;
// __syncthreads();
// size_t load_smem_d_off = ((threads % 32) * 4) + ((threads / 32) * N_PAD);
// float4 d_0_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
// float4 d_0_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
// float4 d_0_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
// float4 d_0_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
// float4 d_0_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
// float4 d_0_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
// float4 d_0_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
// float4 d_0_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));
// __syncthreads();
// smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_1_0.x;
// smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_1_0.y;
// smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.z;
// smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.w;
// smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_1_1.x;
// smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_1_1.y;
// smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.z;
// smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.w;
// smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_1_2.x;
// smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_1_2.y;
// smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.z;
// smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.w;
// smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_1_3.x;
// smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_1_3.y;
// smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.z;
// smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.w;
// smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_1_4.x;
// smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_1_4.y;
// smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.z;
// smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.w;
// smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_1_5.x;
// smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_1_5.y;
// smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.z;
// smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.w;
// smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_1_6.x;
// smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_1_6.y;
// smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_1_6.z;
// smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_1_6.w;
// smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_1_7.x;
// smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_1_7.y;
// smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_1_7.z;
// smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_1_7.w;
// __syncthreads();
// float4 d_1_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
// float4 d_1_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
// float4 d_1_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
// float4 d_1_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
// float4 d_1_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
// float4 d_1_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
// float4 d_1_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
// float4 d_1_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));
// __syncthreads();
// float *global_d = &data0[((grid_m * 64) * N) + (grid_n * 128) + ((threads % 32) * 4) + ((threads / 32) * N)];
// *((float4 *)(global_d + 0*N)) = d_0_0;
// *((float4 *)(global_d + 4*N)) = d_0_1;
// *((float4 *)(global_d + 8*N)) = d_0_2;
// *((float4 *)(global_d + 12*N)) = d_0_3;
// *((float4 *)(global_d + 16*N)) = d_0_4;
// *((float4 *)(global_d + 20*N)) = d_0_5;
// *((float4 *)(global_d + 24*N)) = d_0_6;
// *((float4 *)(global_d + 28*N)) = d_0_7;
// *((float4 *)(global_d + 32*N)) = d_1_0;
// *((float4 *)(global_d + 36*N)) = d_1_1;
// *((float4 *)(global_d + 40*N)) = d_1_2;
// *((float4 *)(global_d + 44*N)) = d_1_3;
// *((float4 *)(global_d + 48*N)) = d_1_4;
// *((float4 *)(global_d + 52*N)) = d_1_5;
// *((float4 *)(global_d + 56*N)) = d_1_6;
// *((float4 *)(global_d + 60*N)) = d_1_7;
// slower path: write the half accumulator elements one by one to data0
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 32*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
}


@@ -0,0 +1,465 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
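// each block computes a 256x128 tile of C with 256 threads (blockDim 32x4x2, 8 warps in a 4x2 M x N layout);
// K is consumed 32 elements at a time through a 3-stage cp.async pipeline with flat (unswizzled) shared-memory
// tiles, and each warp issues 4x8 m16n8k16 MMAs per 16-element K step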
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
extern __shared__ char smem[];
half *smem_a_0 = (half *)(smem);
half *smem_a_1 = (half *)(smem + 16384);
half *smem_a_2 = (half *)(smem + 32768);
half *smem_b_0 = (half *)(smem + 49152);
half *smem_b_1 = (half *)(smem + 57344);
half *smem_b_2 = (half *)(smem + 65536);
int grid_m = blockIdx.x; /* M//256 */
int grid_n = blockIdx.y; /* N//128 */
int wg_threads = threadIdx.x; // 32
int wg_m = threadIdx.y; // 4
int wg_n = threadIdx.z; // 2
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
int num_k_blocks = K / 32;
// load indexes
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + ((threads / 4) * K);
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// unswizzled smem store
size_t store_smem_a_off = ((threads % 4) * 8) + ((threads / 4) * 32); // 64 rows / 32 cols per copy
size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); // 16 rows / 128 cols per copy
// ldmatrix indices
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// unswizzled ldmatrix
size_t load_smem_a_0_k_0 = (wg_m * 16 * 32) + ((wg_threads % 16) * 32) + ((wg_threads / 16) * 8);
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + ( 64 * 32);
size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + (128 * 32);
size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (192 * 32);
size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + ( 64 * 32);
size_t load_smem_a_2_k_1 = load_smem_a_0_k_1 + (128 * 32);
size_t load_smem_a_3_k_1 = load_smem_a_0_k_1 + (192 * 32);
size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 16) * 128) + ((wg_threads / 16) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32;
size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64;
size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96;
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_0_k_1 + 32;
size_t load_smem_b_2_k_1 = load_smem_b_0_k_1 + 64;
size_t load_smem_b_3_k_1 = load_smem_b_0_k_1 + 96;
// create accs (M=4, N=8)
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements
half8 a_frag_0_k_0;
half8 a_frag_1_k_0;
half8 a_frag_2_k_0;
half8 a_frag_3_k_0;
half8 a_frag_0_k_1;
half8 a_frag_1_k_1;
half8 a_frag_2_k_1;
half8 a_frag_3_k_1;
// create register for block B elements
half4 b_frag_0_k_0;
half4 b_frag_1_k_0;
half4 b_frag_2_k_0;
half4 b_frag_3_k_0;
half4 b_frag_4_k_0;
half4 b_frag_5_k_0;
half4 b_frag_6_k_0;
half4 b_frag_7_k_0;
half4 b_frag_0_k_1;
half4 b_frag_1_k_1;
half4 b_frag_2_k_1;
half4 b_frag_3_k_1;
half4 b_frag_4_k_1;
half4 b_frag_5_k_1;
half4 b_frag_6_k_1;
half4 b_frag_7_k_1;
__syncthreads();
// load first tile
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// load second tile
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// wait on first pre-fetch load
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 for the first tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
int phase_k = block_k % 3;
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
int next_phase_k = (block_k+1) % 3;
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
int store_phase_k = (block_k+2) % 3;
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
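// three-stage rotation: the MMAs read from smem_*_curr, the K=0 ldmatrix loads for the next iteration read
// from smem_*_next, and the cp.async prefetch for the tile two iterations ahead targets smem_*_store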
// load K=1 elements for the current tile
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
// MMA K=0, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
// load next tile
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
global_a_off += 32;
global_b_off += 32 * N;
}
__pipeline_commit();
// wait next tile
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 for the next tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
// MMA K=1, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// slower path: write the accumulator elements one by one to data0
size_t wg_c_off = ((grid_m * 256) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 64*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
wg_c_off += 64*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_2_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_2_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_2_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_2_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_2_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_2_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_2_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_2_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_2_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_2_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_2_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_2_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_2_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_2_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_2_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_2_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_2_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_2_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_2_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_2_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_2_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_2_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_2_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_2_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_2_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_2_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_2_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_2_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_2_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_2_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_2_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_2_7.w;
wg_c_off += 64*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_3_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_3_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_3_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_3_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_3_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_3_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_3_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_3_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_3_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_3_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_3_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_3_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_3_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_3_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_3_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_3_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_3_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_3_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_3_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_3_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_3_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_3_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_3_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_3_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_3_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_3_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_3_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_3_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_3_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_3_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_3_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_3_7.w;
}


@@ -0,0 +1,517 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
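// same 256x128 block tile and 3-stage cp.async pipeline as the previous variant, but the shared-memory tiles
// are stored with the XOR swizzle to avoid bank conflicts; the commented-out index computations below preserve
// the earlier unswizzled and reshaped layouts for comparison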
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
extern __shared__ char smem[];
half *smem_a_0 = (half *)(smem);
half *smem_a_1 = (half *)(smem + 16384);
half *smem_a_2 = (half *)(smem + 32768);
half *smem_b_0 = (half *)(smem + 49152);
half *smem_b_1 = (half *)(smem + 57344);
half *smem_b_2 = (half *)(smem + 65536);
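// NOTE (illustrative sketch, not part of the original file): the smem pointers above imply
// 65536 + 8192 = 73728 bytes of dynamic shared memory, which is above the 48KB default,
// so a host-side launch would look roughly like:
//   cudaFuncSetAttribute(wmma_example, cudaFuncAttributeMaxDynamicSharedMemorySize, 73728);
//   wmma_example<<<dim3(M/256, N/128), dim3(32, 4, 2), 73728>>>(data0, data1, data2, N, K);
// assuming M % 256 == 0, N % 128 == 0, and K a multiple of 32 with K >= 64 (two tiles are prefetched).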
int grid_m = blockIdx.x; /* M//256 */
int grid_n = blockIdx.y; /* N//128 */
int wg_threads = threadIdx.x; // lane within the warp, 0..31
int wg_m = threadIdx.y; // warp's M index, 0..3
int wg_n = threadIdx.z; // warp's N index, 0..1
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
int num_k_blocks = K / 32;
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// unswizzled A - SMEM_A is 256 rows x 32 cols
// size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + ((threads / 4) * K);
// size_t store_smem_a_off = ((threads % 4) * 8) + ((threads / 4) * 32); // 64 rows / 32 cols per copy
// size_t load_smem_a_0_k_0 = (wg_m * 16 * 32) + ((wg_threads % 16) * 32) + ((wg_threads / 16) * 8);
// size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + ( 64 * 32);
// size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + (128 * 32);
// size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (192 * 32);
// size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
// size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + ( 64 * 32);
// size_t load_smem_a_2_k_1 = load_smem_a_0_k_1 + (128 * 32);
// size_t load_smem_a_3_k_1 = load_smem_a_0_k_1 + (192 * 32);
// unswizzled reshaped A - SMEM_A is 128 rows x 64 cols, [ (M=0, K=0), (M=0, K=1), (M=8, K=0), (M=8, K=1) ], etc.
// size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
// size_t store_smem_a_off = ((threads % 8) * 8) + ((threads / 8) * 64); // 32 rows / 64 cols per copy
// size_t load_smem_a_0_k_0 = (wg_m * 16 * 64) + ((wg_threads % 16) * 64) + ((wg_threads / 16) * 8);
// size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (64 * 64);
// size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + 32;
// size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (64 * 64) + 32;
// size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
// size_t load_smem_a_1_k_1 = load_smem_a_1_k_0 + 16;
// size_t load_smem_a_2_k_1 = load_smem_a_2_k_0 + 16;
// size_t load_smem_a_3_k_1 = load_smem_a_3_k_0 + 16;
// swizzled A
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
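// illustration (added comment): the store places logical 8-half chunk c of smem row r at physical
// chunk (c ^ (r % 8)), e.g. thread 0 -> row 0, halfs [0,8); thread 9 -> row 1, halfs [0,8) (1^1=0);
// thread 17 -> row 2, halfs [24,32) (1^2=3). The ldmatrix load offsets below apply the same XOR
// (threads % 8 equals row % 8 there), so the 8 rows read for each 8x8 matrix hit 8 distinct 16B chunks.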
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
size_t load_smem_a_phase = (threads / 16) % 2;
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
// unswizzled B
// size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); // 16 rows / 128 cols per copy
// size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 16) * 128) + ((wg_threads / 16) * 8);
// size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32;
// size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64;
// size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96;
// size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
// size_t load_smem_b_1_k_1 = load_smem_b_0_k_1 + 32;
// size_t load_smem_b_2_k_1 = load_smem_b_0_k_1 + 64;
// size_t load_smem_b_3_k_1 = load_smem_b_0_k_1 + 96;
// swizzled B
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy
size_t load_smem_b_row = (threads % 16) * 128;
size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
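// illustration (added comment): same scheme for B - row r of SMEM_B stores logical chunk c at
// physical chunk (c ^ (r % 8)) (only the low 3 bits of c flip), e.g. thread 35 -> row 2,
// logical chunk 3, physical chunk 3^2=1, i.e. halfs [2*128 + 8, 2*128 + 16). The load offsets
// just above apply the matching XOR ((phase + n) ^ (threads % 8)), with threads % 8 equal to row % 8.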
// create accs (M=4, N=8)
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements
half8 a_frag_0_k_0;
half8 a_frag_1_k_0;
half8 a_frag_2_k_0;
half8 a_frag_3_k_0;
half8 a_frag_0_k_1;
half8 a_frag_1_k_1;
half8 a_frag_2_k_1;
half8 a_frag_3_k_1;
// create registers for block B elements
half4 b_frag_0_k_0;
half4 b_frag_1_k_0;
half4 b_frag_2_k_0;
half4 b_frag_3_k_0;
half4 b_frag_4_k_0;
half4 b_frag_5_k_0;
half4 b_frag_6_k_0;
half4 b_frag_7_k_0;
half4 b_frag_0_k_1;
half4 b_frag_1_k_1;
half4 b_frag_2_k_1;
half4 b_frag_3_k_1;
half4 b_frag_4_k_1;
half4 b_frag_5_k_1;
half4 b_frag_6_k_1;
half4 b_frag_7_k_1;
__syncthreads();
// load first tile
// unswizzled 256 x 32
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
// 128 x 64 (swizzled store offsets)
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// load second tile
// unswizzled 256 x 32
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
// 128 x 64 (swizzled store offsets)
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// wait on first pre-fetch load
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 for the first tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
int phase_k = block_k % 3;
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
int next_phase_k = (block_k+1) % 3;
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
int store_phase_k = (block_k+2) % 3;
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
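// illustrative rotation of the three smem buffers (curr = compute, next = already in flight, store = prefetch target):
//   block_k=0: curr=buf0 next=buf1 store=buf2
//   block_k=1: curr=buf1 next=buf2 store=buf0
//   block_k=2: curr=buf2 next=buf0 store=buf1   (and so on, mod 3)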
// load K=1 elements for the current tile
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
// MMA K=0, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
// load next tile if needed
if (block_k < (num_k_blocks-2)) {
// unswizzled 256 x 32
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
// 128 x 64 (swizzled store offsets)
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
global_a_off += 32;
global_b_off += 32 * N;
}
__pipeline_commit();
// wait for the next tile's async copies to complete
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 for the next tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
// MMA K=1, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// slower way: write accs one by one to data0
size_t wg_c_off = ((grid_m * 256) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
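// added note: per the PTX m16n8k16 f16 accumulator layout, lane l holds columns (l%4)*2 and
// (l%4)*2+1 of row l/4 in .x/.y and of row l/4 + 8 in .z/.w, hence the +1 and +(8*N) terms below.
// the (0,1,4,5,8,9,12,13)*8 column offsets are where this warp's eight B fragments land; the other
// wg_n warp (offset +16 in wg_c_off) covers the remaining multiples of 8 within the 128-wide tile.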
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 64*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
wg_c_off += 64*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_2_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_2_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_2_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_2_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_2_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_2_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_2_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_2_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_2_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_2_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_2_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_2_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_2_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_2_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_2_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_2_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_2_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_2_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_2_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_2_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_2_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_2_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_2_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_2_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_2_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_2_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_2_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_2_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_2_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_2_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_2_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_2_7.w;
wg_c_off += 64*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_3_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_3_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_3_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_3_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_3_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_3_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_3_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_3_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_3_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_3_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_3_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_3_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_3_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_3_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_3_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_3_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_3_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_3_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_3_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_3_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_3_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_3_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_3_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_3_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_3_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_3_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_3_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_3_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_3_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_3_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_3_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_3_7.w;
}


@@ -0,0 +1,482 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define SMEM_N_WIDTH 136
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
extern __shared__ char smem[];
half *smem_a_0 = (half *)(smem);
half *smem_a_1 = (half *)(smem + 16384);
half *smem_a_2 = (half *)(smem + 32768);
half *smem_b_0 = (half *)(smem + 49152);
half *smem_b_1 = (half *)(smem + 57344);
half *smem_b_2 = (half *)(smem + 65536);
int grid_m = blockIdx.x; /* M//256 */
int grid_n = blockIdx.y; /* N//128 */
int wg_threads = threadIdx.x; // lane within the warp, 0..31
int wg_m = threadIdx.y; // warp's M index, 0..3
int wg_n = threadIdx.z; // warp's N index, 0..1
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
int num_k_blocks = K / 32;
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled A - SMEM_A is 128 rows x 64 cols
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
size_t load_smem_a_phase = (threads / 16) % 2;
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
// swizzled B - SMEM_B is 32 rows x 128 cols
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy
size_t load_smem_b_row = (threads % 16) * 128;
size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
// create accs (M=4, N=8)
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements
half8 a_frag_0_k_0;
half8 a_frag_1_k_0;
half8 a_frag_2_k_0;
half8 a_frag_3_k_0;
half8 a_frag_0_k_1;
half8 a_frag_1_k_1;
half8 a_frag_2_k_1;
half8 a_frag_3_k_1;
// create registers for block B elements
half4 b_frag_0_k_0;
half4 b_frag_1_k_0;
half4 b_frag_2_k_0;
half4 b_frag_3_k_0;
half4 b_frag_4_k_0;
half4 b_frag_5_k_0;
half4 b_frag_6_k_0;
half4 b_frag_7_k_0;
half4 b_frag_0_k_1;
half4 b_frag_1_k_1;
half4 b_frag_2_k_1;
half4 b_frag_3_k_1;
half4 b_frag_4_k_1;
half4 b_frag_5_k_1;
half4 b_frag_6_k_1;
half4 b_frag_7_k_1;
__syncthreads();
// load first tile
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// load second tile
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// wait on first pre-fetch load
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 elements for the first tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
int phase_k = block_k % 3;
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
int next_phase_k = (block_k+1) % 3;
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
int store_phase_k = (block_k+2) % 3;
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
// load K=1 elements for the current tile
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
// MMA K=0, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
// load next tile if needed
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
global_a_off += 32;
global_b_off += 32 * N;
}
__pipeline_commit();
// wait for the next tile's async copies to complete
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 elements for the next tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
// MMA K=1, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// faster epilogue: stage each warp's 8x8 TC accumulator tiles in SMEM first
// - SMEM_N_WIDTH is 8 halfs wider than the 128 needed, to avoid shared-memory bank conflicts
// - around 14 microseconds
// - check bank conflicts (run under sudo) with: "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py"
// epilogue with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (lo/hi halves of each of the 4 TC M tiles)
// 1) write 32 rows of 128 cols to SMEM (rows 0-7, 16-23, 32-39, 48-55 from the .x/.y halves of acc_frag_0_*, then their .z/.w halves, etc.)
// 2) read the staged rows back and write them to data0 in 8-element (16B) chunks, 16 rows of 128 per store
half2 *smem32_d = (half2 *)(smem);
half8 *smem128_d = (half8 *)(smem);
half8 *out128_d = (half8 *)(data0);
size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2));
size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4);
size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16);
size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) +
((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16);
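// accounting (added comment): each pass below stages 32 rows x 128 cols in smem
// (256 threads x 8 half2 stores = 4096 halfs), then every thread copies two 16B chunks back out,
// so the two out128_d writes per pass cover those same 32 rows (16 rows x 128 cols per write);
// 8 such passes cover the block's full 256 x 128 output tile.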
// write acc_frag_0_*
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_1_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_2_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_3_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.z, acc_frag_3_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
__syncthreads();
}


@@ -0,0 +1,486 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define SMEM_N_WIDTH 136
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
extern __shared__ char smem[];
half *smem_a_0 = (half *)(smem);
half *smem_a_1 = (half *)(smem + 16384);
half *smem_a_2 = (half *)(smem + 32768);
half *smem_b_0 = (half *)(smem + 49152);
half *smem_b_1 = (half *)(smem + 57344);
half *smem_b_2 = (half *)(smem + 65536);
int grid_m = blockIdx.x; /* M//256 */
int grid_n = blockIdx.y; /* N//128 */
int wg_threads = threadIdx.x; // 32
int wg_m = threadIdx.y; // 4
int wg_n = threadIdx.z; // 2
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
int num_k_blocks = K / 32;
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled A - SMEM_A is 128 rows x 64 cols
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
size_t load_smem_a_phase = (threads / 16) % 2;
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
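// note: these loads undo the store swizzle: the logical 8-half column chunk (load_smem_a_phase + 0/2/4/6)
// is XORed with the row index mod 8 (threads % 8), the same mapping as the (((threads * 8) ^ threads) & 56)
// term used when storing into SMEM_A above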
// swizzled B - SMEM_B is 64 rows x 64 cols
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
size_t store_smem_b_off = // 32 rows of 64 cols per copy
((threads / 128) * (64)) + // [A,C] vs [B,D] in ldmatrix
((threads % 2) * (2 * 64)) + // [A vs C] or [B vs. D]
(((threads / 2) % 2) * (4 * 64)) + // WG_N in [0, 1]
(((threads / 4) % 4) * (8 * 64)) + // B in [0, 1, 2, 3]
(((threads / 16) % 8) * (8)); // cols in SMEM_B i.e. rows of 8x8
size_t load_smem_b_0_k_0 = (wg_n * 4 * 64) + ((wg_threads / 8) * 64) + ((wg_threads % 8) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + ( 8 * 64);
size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + (16 * 64);
size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + (24 * 64);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (32 * 64);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (32 * 64);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (32 * 64);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (32 * 64);
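// note: store_smem_b_off above scatters the 16B global chunks so that each 64-half (128B) SMEM_B row
// collects the 8 rows of one 8x8 ldmatrix tile; the loads here then hand each group of 8 lanes
// consecutive 16B chunks of a single row, which is bank-conflict free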
// create accs (M=4, N=8)
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements
half8 a_frag_0_k_0;
half8 a_frag_1_k_0;
half8 a_frag_2_k_0;
half8 a_frag_3_k_0;
half8 a_frag_0_k_1;
half8 a_frag_1_k_1;
half8 a_frag_2_k_1;
half8 a_frag_3_k_1;
// create register for block B elements
half4 b_frag_0_k_0;
half4 b_frag_1_k_0;
half4 b_frag_2_k_0;
half4 b_frag_3_k_0;
half4 b_frag_4_k_0;
half4 b_frag_5_k_0;
half4 b_frag_6_k_0;
half4 b_frag_7_k_0;
half4 b_frag_0_k_1;
half4 b_frag_1_k_1;
half4 b_frag_2_k_1;
half4 b_frag_3_k_1;
half4 b_frag_4_k_1;
half4 b_frag_5_k_1;
half4 b_frag_6_k_1;
half4 b_frag_7_k_1;
__syncthreads();
// load first tile
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// load second tile
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// wait on first pre-fetch load
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 elements for the first tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
int phase_k = block_k % 3;
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
int next_phase_k = (block_k+1) % 3;
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
int store_phase_k = (block_k+2) % 3;
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
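// triple-buffer roles for this iteration:
//   smem_*_curr  - tile being consumed by the MMAs below
//   smem_*_next  - tile whose K=0 fragments are loaded via ldmatrix at the bottom of this iteration
//   smem_*_store - destination of the cp.async prefetch issued this iteration (tile block_k+2)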
// load K=1 elements for the current tile
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
// MMA K=0, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
// load next tile if needed
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16);
global_a_off += 32;
global_b_off += 32 * N;
}
__pipeline_commit();
// wait for the next tile's loads to complete
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 elements for the next tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
// MMA K=1, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// faster epilogue: stage the 8x8 halves of each TC accumulator in SMEM first
// - SMEM_N_WIDTH is 8 columns wider than the 128 needed, to avoid shared memory bank conflicts
// - around 14 microseconds
// - check bank conflicts (may need sudo) with: "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py"
// epilogue with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (lo/hi halves of each of the 4 TC M blocks)
// 1) write 32 rows of 128 cols to SMEM (rows 0-7, 16-23, 32-39, 48-55 in acc_frag_0.lo, then acc_frag_0.hi, etc.)
// 2) read back and store to global: each thread copies two 8-element (16B) chunks, one from each 16-row half
half2 *smem32_d = (half2 *)(smem);
half8 *smem128_d = (half8 *)(smem);
half8 *out128_d = (half8 *)(data0);
size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2));
size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4);
size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16);
size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) +
((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16);
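// per-lane accumulator layout of mma.m16n8k16 with f16 accumulators: lane L holds {x,y} of row (L/4)
// and {z,w} of row (L/4 + 8), at columns 2*(L%4) and 2*(L%4)+1 of its 16x8 tile, which is exactly
// what smem32_d_thread_off encodes in half2 units
// e.g. lane 5 of the warp with wg_m=1, wg_n=0 writes its first half2 to
//   smem32_d[1*8*68 + (5/4)*68 + (5%4)] = smem32_d[613], i.e. SMEM row 9, halves 2-3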
// write acc_frag_0_*
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_1_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_2_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_3_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.z, acc_frag_3_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
__syncthreads();
}

View File

@@ -0,0 +1,157 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
struct __align__(8) half4 { half x, y, z, w; }; __device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;}
extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2) {
int gidx0 = blockIdx.x; /* 32 */
int gidx1 = blockIdx.y; /* 64 */
int lidx0 = threadIdx.x; /* 16 */
int lidx1 = threadIdx.y; /* 2 */
int lidx2 = threadIdx.z; /* 4 */
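// generated tinygrad kernel: grid 32x64, block 16x2x4 = 128 threads; the hard-coded offsets below
// (4096-element row stride, 256 iterations of K=16) imply a 4096x4096x4096 half matmul with
// float accumulators that are downcast to half on the final stores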
float4 cast0 = make_float4(0.0f,0.0f,0.0f,0.0f);
int alu0 = (gidx0*128);
int alu1 = (gidx1*262144);
int alu2 = (lidx1*32768);
int alu3 = (lidx2*32);
int alu4 = (lidx0/8);
int alu5 = (alu4*16384);
int alu6 = (lidx0%2);
int alu7 = (alu6*2);
int alu8 = ((lidx0/2)%2);
int alu9 = (alu8*4);
int alu10 = ((lidx0/4)%2);
int alu11 = (alu10*8192);
int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3);
int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2);
float4 acc0 = cast0;
float4 acc1 = cast0;
float4 acc2 = cast0;
float4 acc3 = cast0;
float4 acc4 = cast0;
float4 acc5 = cast0;
float4 acc6 = cast0;
float4 acc7 = cast0;
float4 acc8 = cast0;
float4 acc9 = cast0;
float4 acc10 = cast0;
float4 acc11 = cast0;
float4 acc12 = cast0;
float4 acc13 = cast0;
float4 acc14 = cast0;
float4 acc15 = cast0;
for (int ridx0 = 0; ridx0 < 256; ridx0++) {
int alu14 = (ridx0*16);
int alu15 = (alu13+alu14);
int alu16 = (alu14+alu13);
int alu17 = (alu0+(alu6*8192)+(alu8*16384)+alu10+(alu4*2)+(lidx1*4)+alu3+(ridx0*65536));
half val0 = data2[alu17+8];
half val1 = data2[alu17+16];
half val2 = data2[alu17+24];
half val3 = data2[alu17+4096];
half val4 = data2[alu17+4104];
half val5 = data2[alu17+4112];
half val6 = data2[alu17+4120];
half val7 = data2[alu17+32768];
half val8 = data2[alu17+32776];
half val9 = data2[alu17+32784];
half val10 = data2[alu17+32792];
half val11 = data2[alu17+36864];
half val12 = data2[alu17+36872];
half4 cast1 = make_half4(val0,val4,val8,val12);
half val13 = data2[alu17+36880];
half4 cast2 = make_half4(val1,val5,val9,val13);
half val14 = data2[alu17+36888];
half4 cast3 = make_half4(val2,val6,val10,val14);
half val15 = data2[alu17];
half4 cast4 = make_half4(val15,val3,val7,val11);
half2 val16 = *((half2*)(data1+alu15+4096));
half2 val17 = *((half2*)(data1+alu15+65536));
half2 val18 = *((half2*)(data1+alu15+69632));
half2 val19 = *((half2*)(data1+alu15+131072));
half2 val20 = *((half2*)(data1+alu15+135168));
half2 val21 = *((half2*)(data1+alu15+196608));
half2 val22 = *((half2*)(data1+alu15+200704));
half2 val23 = *((half2*)(data1+alu15));
half2 val24 = *((half2*)(data1+alu16+8));
half2 val25 = *((half2*)(data1+alu16+4104));
half8 cast5 = make_half8(val23.x,val23.y,val16.x,val16.y,val24.x,val24.y,val25.x,val25.y);
float4 wmma0 = __WMMA_8_16_16_half_float(cast5, cast1, acc1);
float4 wmma1 = __WMMA_8_16_16_half_float(cast5, cast2, acc2);
float4 wmma2 = __WMMA_8_16_16_half_float(cast5, cast3, acc3);
float4 wmma3 = __WMMA_8_16_16_half_float(cast5, cast4, acc0);
half2 val26 = *((half2*)(data1+alu16+65544));
half2 val27 = *((half2*)(data1+alu16+69640));
half8 cast6 = make_half8(val17.x,val17.y,val18.x,val18.y,val26.x,val26.y,val27.x,val27.y);
float4 wmma4 = __WMMA_8_16_16_half_float(cast6, cast1, acc5);
float4 wmma5 = __WMMA_8_16_16_half_float(cast6, cast2, acc6);
float4 wmma6 = __WMMA_8_16_16_half_float(cast6, cast3, acc7);
float4 wmma7 = __WMMA_8_16_16_half_float(cast6, cast4, acc4);
half2 val28 = *((half2*)(data1+alu16+131080));
half2 val29 = *((half2*)(data1+alu16+135176));
half8 cast7 = make_half8(val19.x,val19.y,val20.x,val20.y,val28.x,val28.y,val29.x,val29.y);
float4 wmma8 = __WMMA_8_16_16_half_float(cast7, cast1, acc9);
float4 wmma9 = __WMMA_8_16_16_half_float(cast7, cast2, acc10);
float4 wmma10 = __WMMA_8_16_16_half_float(cast7, cast3, acc11);
float4 wmma11 = __WMMA_8_16_16_half_float(cast7, cast4, acc8);
half2 val30 = *((half2*)(data1+alu16+196616));
half2 val31 = *((half2*)(data1+alu16+200712));
half8 cast8 = make_half8(val21.x,val21.y,val22.x,val22.y,val30.x,val30.y,val31.x,val31.y);
float4 wmma12 = __WMMA_8_16_16_half_float(cast8, cast1, acc13);
float4 wmma13 = __WMMA_8_16_16_half_float(cast8, cast2, acc14);
float4 wmma14 = __WMMA_8_16_16_half_float(cast8, cast3, acc15);
float4 wmma15 = __WMMA_8_16_16_half_float(cast8, cast4, acc12);
acc0 = wmma3;
acc1 = wmma0;
acc2 = wmma1;
acc3 = wmma2;
acc4 = wmma7;
acc5 = wmma4;
acc6 = wmma5;
acc7 = wmma6;
acc8 = wmma11;
acc9 = wmma8;
acc10 = wmma9;
acc11 = wmma10;
acc12 = wmma15;
acc13 = wmma12;
acc14 = wmma13;
acc15 = wmma14;
}
*((half2*)(data0+alu12+8)) = make_half2((half)(acc1.x),(half)(acc1.y));
*((half2*)(data0+alu12+16)) = make_half2((half)(acc2.x),(half)(acc2.y));
*((half2*)(data0+alu12+24)) = make_half2((half)(acc3.x),(half)(acc3.y));
*((half2*)(data0+alu12+4096)) = make_half2((half)(acc0.z),(half)(acc0.w));
*((half2*)(data0+alu12+4104)) = make_half2((half)(acc1.z),(half)(acc1.w));
*((half2*)(data0+alu12+4112)) = make_half2((half)(acc2.z),(half)(acc2.w));
*((half2*)(data0+alu12+4120)) = make_half2((half)(acc3.z),(half)(acc3.w));
*((half2*)(data0+alu12+65536)) = make_half2((half)(acc4.x),(half)(acc4.y));
*((half2*)(data0+alu12+65544)) = make_half2((half)(acc5.x),(half)(acc5.y));
*((half2*)(data0+alu12+65552)) = make_half2((half)(acc6.x),(half)(acc6.y));
*((half2*)(data0+alu12+65560)) = make_half2((half)(acc7.x),(half)(acc7.y));
*((half2*)(data0+alu12+69632)) = make_half2((half)(acc4.z),(half)(acc4.w));
*((half2*)(data0+alu12+69640)) = make_half2((half)(acc5.z),(half)(acc5.w));
*((half2*)(data0+alu12+69648)) = make_half2((half)(acc6.z),(half)(acc6.w));
*((half2*)(data0+alu12+69656)) = make_half2((half)(acc7.z),(half)(acc7.w));
*((half2*)(data0+alu12+131072)) = make_half2((half)(acc8.x),(half)(acc8.y));
*((half2*)(data0+alu12+131080)) = make_half2((half)(acc9.x),(half)(acc9.y));
*((half2*)(data0+alu12+131088)) = make_half2((half)(acc10.x),(half)(acc10.y));
*((half2*)(data0+alu12+131096)) = make_half2((half)(acc11.x),(half)(acc11.y));
*((half2*)(data0+alu12+135168)) = make_half2((half)(acc8.z),(half)(acc8.w));
*((half2*)(data0+alu12+135176)) = make_half2((half)(acc9.z),(half)(acc9.w));
*((half2*)(data0+alu12+135184)) = make_half2((half)(acc10.z),(half)(acc10.w));
*((half2*)(data0+alu12+135192)) = make_half2((half)(acc11.z),(half)(acc11.w));
*((half2*)(data0+alu12+196608)) = make_half2((half)(acc12.x),(half)(acc12.y));
*((half2*)(data0+alu12+196616)) = make_half2((half)(acc13.x),(half)(acc13.y));
*((half2*)(data0+alu12+196624)) = make_half2((half)(acc14.x),(half)(acc14.y));
*((half2*)(data0+alu12+196632)) = make_half2((half)(acc15.x),(half)(acc15.y));
*((half2*)(data0+alu12+200704)) = make_half2((half)(acc12.z),(half)(acc12.w));
*((half2*)(data0+alu12+200712)) = make_half2((half)(acc13.z),(half)(acc13.w));
*((half2*)(data0+alu12+200720)) = make_half2((half)(acc14.z),(half)(acc14.w));
*((half2*)(data0+alu12+200728)) = make_half2((half)(acc15.z),(half)(acc15.w));
*((half2*)(data0+alu12)) = make_half2((half)(acc0.x),(half)(acc0.y));
}

View File

@@ -0,0 +1,398 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
int grid_m = blockIdx.x; /* M//64 */
int grid_n = blockIdx.y; /* N//128 */
int threads = threadIdx.x; /* 128 */
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
int wg_threads = threads%32;
int num_k_blocks = K / 64;
// load indexes
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// swizzled smem store offsets - columns of smem are swizzled
// here's a link to a description of the triton: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
// ldmatrix indices
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled ldmatrix
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
size_t load_smem_a_phase = (threads / 16) % 2; // r4
size_t load_smem_b_row = (threads % 16) * 128; // r299
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
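// SMEM_A is 64 rows x 64 cols (BLOCK_M=64 x BLOCK_K=64), SMEM_B is 64 rows x 128 cols (BLOCK_K=64 x BLOCK_N=128);
// the +2/+4/+6 phase offsets step through the four K=16 sub-blocks of A (still XORed with the row for the
// swizzle), while B steps down by 16 rows per K=16 sub-block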
// create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B_2 16384 bytes)
__shared__ alignas(16) char smem[49152];
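// 2*8192 (A) + 2*16384 (B) = 49152 bytes, exactly the 48KB static shared memory limit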
// create accs (16 WMMAs and 4 output elements each) and zero
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements (2)
half8 a_frag_0;
half8 a_frag_1;
// create register for block B elements (8)
half4 b_frag_0;
half4 b_frag_1;
half4 b_frag_2;
half4 b_frag_3;
half4 b_frag_4;
half4 b_frag_5;
half4 b_frag_6;
half4 b_frag_7;
half *smem_a_even = (half *)(smem);
half *smem_a_odd = (half *)(smem + 8192);
half *smem_b_even = (half *)(smem + 16384);
half *smem_b_odd = (half *)(smem + 32768);
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
// start first pre-fetch load A
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start first pre-fetch load B
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
__syncthreads();
// start second pre-fetch load A
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start second pre-fetch load B
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
// wait on needed prefetch value
__pipeline_wait_prior(0); // TODO: wait_prior(0) here still gives fast iterations, but wait_prior(1) produces incorrect results (it shouldn't)
__syncthreads();
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd;
half *smem_b_curr = (block_k % 2) ? smem_b_even : smem_b_odd;
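// double buffer: compute out of smem_*_curr this iteration, then reuse that same buffer below for the
// prefetch of tile block_k+2; the wait_prior(1)/syncthreads at the bottom of the loop guarantees the
// other buffer's tile has landed before the next iteration reads it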
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// last 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// prefetch next iteration if needed
__syncthreads();
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
global_a_off += 64;
global_b_off += 64 * N;
}
__pipeline_commit();
if (block_k < num_k_blocks-1) {
__pipeline_wait_prior(1);
__syncthreads();
}
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// slower way: write floats one by one to data0
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
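// per-lane accumulator layout of mma.m16n8k16 with f32 accumulators: lane L holds {x,y} of row (L/4)
// and {z,w} of row (L/4 + 8), at columns 2*(L%4) and 2*(L%4)+1 of its 16x8 tile - this is what
// thread_c_off encodes, and why each fragment below writes two adjacent floats in two rows 8 apart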
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 32*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
}

View File

@@ -0,0 +1,363 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
int grid_m = blockIdx.x; /* M//64 */
int grid_n = blockIdx.y; /* N//128 */
int threads = threadIdx.x; /* 128 */
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
int wg_threads = threads%32;
int num_k_blocks = K / 64;
// load indexes
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// non-swizzled smem store offsets - expected to be slow due to shared-memory bank conflicts
size_t store_smem_a_off = ((threads % 8) * 8) + ((threads / 8) * 64);
size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128);
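// note: with this flat layout the 8 row addresses each ldmatrix consumes per 8x8 tile
// differ by the smem row stride (128 bytes for A, 256 bytes for B), a multiple of the
// 128-byte bank width, so the 16-byte reads serialize on the same banks; the swizzled
// variants XOR the column chunk with the row to spread them across banks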
// ldmatrix indices
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// unswizzled ldmatrix
size_t load_smem_a_0_k_0 = (wg_m * 16 * 64) + ((wg_threads % 8) * 64) + (((wg_threads / 8) % 2) * 64 * 8) + ((wg_threads / 16) * 8);
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32*64);
size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 8) * 128) + (((wg_threads / 8) % 2) * 128 * 8) + ((wg_threads / 16) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32;
size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64;
size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96;
size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
size_t load_smem_a_1_k_1 = load_smem_a_1_k_0 + 16;
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
size_t load_smem_a_0_k_2 = load_smem_a_0_k_0 + 32;
size_t load_smem_a_1_k_2 = load_smem_a_1_k_0 + 32;
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
size_t load_smem_a_0_k_3 = load_smem_a_0_k_0 + 48;
size_t load_smem_a_1_k_3 = load_smem_a_1_k_0 + 48;
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
// create shared mem (A 8192 bytes, B 16384 bytes)
__shared__ alignas(16) char smem[24576];
// create accs (16 WMMAs and 4 output elements each) and zero
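// each warp accumulates a 32x64 slice of the 64x128 block tile: the 2 (BLOCK_M) x 8 (BLOCK_N)
// accumulators cover rows {0-15,32-47} or {16-31,48-63} (picked by wg_m) and the even or odd
// 16-column groups (picked by wg_n), which is why the epilogue column offsets step through
// chunks 0,1,4,5,8,9,12,13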
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements (2)
half8 a_frag_0;
half8 a_frag_1;
// create registers for block B elements (8)
half4 b_frag_0;
half4 b_frag_1;
half4 b_frag_2;
half4 b_frag_3;
half4 b_frag_4;
half4 b_frag_5;
half4 b_frag_6;
half4 b_frag_7;
half *smem_a = (half *)(smem);
half *smem_b = (half *)(smem + 8192);
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
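// each thread issues 16-byte cp.async copies: 4 for A (128 threads * 4 * 16B = the 64x64 half
// tile, 8KB) and 8 for B (64x128 half tile, 16KB); the commit below groups them so a single
// wait covers the whole K block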
// start first pre-fetch load A
__pipeline_memcpy_async(&smem_a[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start first pre-fetch load B
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
__syncthreads();
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
// wait on needed prefetch value
__pipeline_wait_prior(0);
__syncthreads();
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
half *smem_a_curr = smem_a;
half *smem_b_curr = smem_b;
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// last 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// prefetch next iteration if needed
__syncthreads();
if (block_k < (num_k_blocks-1)) {
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
global_a_off += 64;
global_b_off += 64 * N;
}
__pipeline_commit();
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// slow path: write floats to data0 one at a time (not vectorized)
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
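// mma.m16n8k16 accumulator mapping: lane/4 selects the row of the 16x8 tile and lane%4 a pair
// of adjacent columns, so .x/.y land at columns 2*(lane%4) and 2*(lane%4)+1 of that row while
// .z/.w land 8 rows below - hence the +1 for .y/.w and the +(8*N) for .z/.w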
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 32*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
}

View File

@@ -0,0 +1,439 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
int grid_m = blockIdx.x; /* M//64 */
int grid_n = blockIdx.y; /* N//128 */
int threads = threadIdx.x; /* 128 */
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
int wg_threads = threads%32;
int num_k_blocks = K / 64;
// load indexes
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// swizzled smem store offsets - columns of smem are swizzled
// here's a description of the Triton swizzle: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled ldmatrix
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
size_t load_smem_a_phase = (threads / 16) % 2; // r4
size_t load_smem_b_row = (threads % 16) * 128; // r299
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
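// the swizzled store above placed logical 16-byte chunk c of row r at chunk (c ^ (r % 8));
// a lane's logical chunk is its phase plus the tile's fixed chunk offset, and r % 8 equals
// threads % 8 because ldmatrix hands one row to each lane, so the same XOR recovers the
// swizzled address here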
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
// create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B_2 16384 bytes)
__shared__ alignas(16) char smem[49152];
// create accs (16 WMMAs and 4 output elements each) and zero
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements (2)
half8 a_frag_0;
half8 a_frag_1;
// create registers for block B elements (8)
half4 b_frag_0;
half4 b_frag_1;
half4 b_frag_2;
half4 b_frag_3;
half4 b_frag_4;
half4 b_frag_5;
half4 b_frag_6;
half4 b_frag_7;
half *smem_a_even = (half *)(smem);
half *smem_a_odd = (half *)(smem + 8192);
half *smem_b_even = (half *)(smem + 16384);
half *smem_b_odd = (half *)(smem + 32768);
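// double buffering: even/odd copies of A and B are alternated by block_k parity so the
// cp.async for a later K block can be in flight while the current block feeds the MMAs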
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
// start first pre-fetch load A
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start first pre-fetch load B
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
__syncthreads();
// start second pre-fetch load A
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start second pre-fetch load B
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
// wait on needed prefetch value
__pipeline_wait_prior(0); // TODO: this keeps iterations fast, but using an argument of 1 here gives incorrect results (it shouldn't)
__syncthreads();
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd;
half *smem_b_curr = (block_k % 2) ? smem_b_even : smem_b_odd;
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// last 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// prefetch next iteration if needed
__syncthreads();
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
global_a_off += 64;
global_b_off += 64 * N;
}
__pipeline_commit();
if (block_k < num_k_blocks-1) {
__pipeline_wait_prior(1);
__syncthreads();
}
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// store registers to smem first, then read back to do float4 writes to global
float *smem_d = (float *)(smem);
size_t smem_d_off = (wg_m * 16 * N_PAD) + (wg_n * 16) + ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N_PAD);
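// N_PAD = 132 pads each 128-float staging row by 16 bytes: a 512-byte row stride is a
// multiple of the 128-byte bank width, so without the padding the row-strided accesses
// in this epilogue would all hit the same shared-memory banks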
smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_0_0.x;
smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_0_0.y;
smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.z;
smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.w;
smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_0_1.x;
smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_0_1.y;
smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.z;
smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.w;
smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_0_2.x;
smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_0_2.y;
smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.z;
smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.w;
smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_0_3.x;
smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_0_3.y;
smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.z;
smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.w;
smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_0_4.x;
smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_0_4.y;
smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.z;
smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.w;
smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_0_5.x;
smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_0_5.y;
smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.z;
smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.w;
smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_0_6.x;
smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_0_6.y;
smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_0_6.z;
smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_0_6.w;
smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_0_7.x;
smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_0_7.y;
smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_0_7.z;
smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_0_7.w;
__syncthreads();
size_t load_smem_d_off = ((threads % 32) * 4) + ((threads / 32) * N_PAD);
float4 d_0_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
float4 d_0_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
float4 d_0_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
float4 d_0_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
float4 d_0_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
float4 d_0_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
float4 d_0_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
float4 d_0_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));
__syncthreads();
smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_1_0.x;
smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_1_0.y;
smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.z;
smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.w;
smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_1_1.x;
smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_1_1.y;
smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.z;
smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.w;
smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_1_2.x;
smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_1_2.y;
smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.z;
smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.w;
smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_1_3.x;
smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_1_3.y;
smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.z;
smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.w;
smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_1_4.x;
smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_1_4.y;
smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.z;
smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.w;
smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_1_5.x;
smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_1_5.y;
smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.z;
smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.w;
smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_1_6.x;
smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_1_6.y;
smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_1_6.z;
smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_1_6.w;
smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_1_7.x;
smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_1_7.y;
smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_1_7.z;
smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_1_7.w;
__syncthreads();
float4 d_1_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
float4 d_1_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
float4 d_1_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
float4 d_1_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
float4 d_1_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
float4 d_1_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
float4 d_1_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
float4 d_1_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));
__syncthreads();
float *global_d = &data0[((grid_m * 64) * N) + (grid_n * 128) + ((threads % 32) * 4) + ((threads / 32) * N)];
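// each float4 store below has a warp writing one full 128-float row (512 contiguous bytes),
// so the global writes are fully coalesced, unlike the scalar epilogue of the flat variant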
*((float4 *)(global_d + 0*N)) = d_0_0;
*((float4 *)(global_d + 4*N)) = d_0_1;
*((float4 *)(global_d + 8*N)) = d_0_2;
*((float4 *)(global_d + 12*N)) = d_0_3;
*((float4 *)(global_d + 16*N)) = d_0_4;
*((float4 *)(global_d + 20*N)) = d_0_5;
*((float4 *)(global_d + 24*N)) = d_0_6;
*((float4 *)(global_d + 28*N)) = d_0_7;
*((float4 *)(global_d + 32*N)) = d_1_0;
*((float4 *)(global_d + 36*N)) = d_1_1;
*((float4 *)(global_d + 40*N)) = d_1_2;
*((float4 *)(global_d + 44*N)) = d_1_3;
*((float4 *)(global_d + 48*N)) = d_1_4;
*((float4 *)(global_d + 52*N)) = d_1_5;
*((float4 *)(global_d + 56*N)) = d_1_6;
*((float4 *)(global_d + 60*N)) = d_1_7;
}

View File

@@ -0,0 +1,371 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
int grid_m = blockIdx.x; /* M//64 */
int grid_n = blockIdx.y; /* N//128 */
int threads = threadIdx.x; /* 128 */
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
int wg_threads = threads%32;
int num_k_blocks = K / 64;
// load indexes
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
// swizzled smem store offsets - columns of smem are swizzled
// here's a description of the Triton swizzle: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
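// the XOR places 16-byte chunk c of row r at chunk (c ^ (r % 8)) for A and (c ^ r) for B,
// so the 8 rows read together by an ldmatrix no longer map to the same banks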
// ldmatrix indices
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled ldmatrix
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
size_t load_smem_a_phase = (threads / 16) % 2; // r4
size_t load_smem_b_row = (threads % 16) * 128; // r299
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
// create shared mem (A 8192 bytes, B 16384 bytes)
__shared__ alignas(16) char smem[24576];
// create accs (16 WMMAs and 4 output elements each) and zero
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
// create registers for block A elements (2)
half8 a_frag_0;
half8 a_frag_1;
// create registers for block B elements (8)
half4 b_frag_0;
half4 b_frag_1;
half4 b_frag_2;
half4 b_frag_3;
half4 b_frag_4;
half4 b_frag_5;
half4 b_frag_6;
half4 b_frag_7;
half *smem_a = (half *)(smem);
half *smem_b = (half *)(smem + 8192);
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
// start first pre-fetch load A
__pipeline_memcpy_async(&smem_a[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
// start first pre-fetch load B
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
__pipeline_commit();
global_a_off += 64;
global_b_off += 64 * N;
__syncthreads();
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
// wait on needed prefetch value
__pipeline_wait_prior(0);
__syncthreads();
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
half *smem_a_curr = smem_a;
half *smem_b_curr = smem_b;
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// next 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
// last 16 K elements
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
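// The four unrolled 16-wide K steps above (k_0..k_3) are the hand-unrolled form of
// (illustrative restatement only, not compiled code):
//   for (int ks = 0; ks < 4; ks++) {
//     // 2x ldmatrix for A (a_frag_0, a_frag_1), 4x ldmatrix.x4 for B (b_frag_0..b_frag_7)
//     for (int m = 0; m < 2; m++)
//       for (int n = 0; n < 8; n++)
//         acc_frag_<m>_<n> = __WMMA_8_16_16_half_float(a_frag_<m>, b_frag_<n>, acc_frag_<m>_<n>);
//   }
// i.e. 2*8 = 16 HMMAs per warp per K step, and 4*16 = 64 per warp per 64-wide K block.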
// prefetch next iteration if needed
__syncthreads();
if (block_k < (num_k_blocks-1)) {
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
global_a_off += 64;
global_b_off += 64 * N;
}
__pipeline_commit();
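// NOTE: the cp.async copies queued above are committed here as one group; they must be drained
// by a matching __pipeline_wait_prior() before the freshly written SMEM is read (the wait before
// the epilogue is visible below; presumably another wait guards the next iteration's ldmatrix
// loads earlier in the loop body, outside this excerpt). Each K block advances A by 64 columns
// (global_a_off += 64) and B by 64 rows (global_b_off += 64*N).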
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// slower way: write floats one by one to data0
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
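// Per-thread accumulator layout for mma.m16n8k16 with f32 accumulate: within each warp, lane
// wg_threads holds two adjacent columns ((wg_threads%4)*2 and +1) of row wg_threads/4 in .x/.y
// and of row wg_threads/4 + 8 in .z/.w, which is what thread_c_off and the "+ (8*N)" terms below
// encode. The (n*8) column offsets step through this warpgroup's eight n8 output blocks; they
// come in pairs (0,1), (4,5), (8,9), (12,13) because the two warpgroups in N interleave
// 16-column chunks across the 128-wide tile.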
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 32*N;
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
}
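// Epilogue note (sketch only, not part of this kernel): because each fragment's .x/.y (and .z/.w)
// land in adjacent output columns, the scalar stores above could be pairwise packed into 8-byte
// vector stores, assuming a float C matrix with 8-byte-aligned rows, e.g.:
//   *reinterpret_cast<float2*>(&data0[wg_c_off + thread_c_off + (0*8)]) = make_float2(acc_frag_0_0.x, acc_frag_0_0.y);
// Some of the other variants in this directory go further and add a dedicated epilogue for the
// output writes (see the "epilogue" variants selected in max_matmul.py below).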

extra/gemm/max_matmul.py Normal file

@@ -0,0 +1,232 @@
import numpy as np, os
from tinygrad.helpers import getenv, flat_mv
from tinygrad import dtypes
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Self
# for copied uops
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps, KernelInfo
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.dtype import PtrDType, DType, DTYPES_DICT
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
script_dir = os.path.dirname(os.path.abspath(__file__))
# problem variations
DTYPE_IN = DTYPES_DICT[getenv("DTYPE_IN", "half")]
DTYPE_OUT = DTYPES_DICT[getenv("DTYPE_OUT", "half")]
DTYPE_ACC = DTYPES_DICT[getenv("DTYPE_ACC", "float")]
N = getenv("N", 4096)
M = getenv("M", N)
K = getenv("K", N)
CNT = getenv("CNT", 10)
ATOL = getenv("ATOL", 5e-3 if DTYPE_IN == dtypes.float else 1e-2)
RTOL = getenv("RTOL", 1e-4 if DTYPE_IN == dtypes.float else 1e-3)
FLOPS = M * N * K * 2
BW = 2 * ((M*K) + (K*N) + (M*N))
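# FLOPS counts 2 ops (multiply + add) per MAC over the M*N*K MACs; BW assumes 2-byte (half)
# elements and one read of A (M*K) and B (K*N) plus one write of C (M*N), so it undercounts
# traffic when DTYPE_IN/DTYPE_OUT are 4-byte floats.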
# algorithm variations
INPUT = getenv("INPUT", "RAND")
GEMM_VARIATION = getenv("GEMM_VARIATION", "nv_hcopt")
def randoms():
  if INPUT == "RAND":
    na = np.random.default_rng().normal(scale=1.0, size=(M,K)).astype(dtype=np.float32)
    nb = np.random.default_rng().normal(scale=1.0, size=(K,N)).astype(dtype=np.float32)
  elif INPUT == "IDENTITY" and M==N==K:
    na = np.identity(K, dtype=np.float32)
    nb = np.identity(K, dtype=np.float32)
  elif INPUT == "OUTPUTONES" and M==K:
    na = np.identity(K, dtype=np.float32)
    nb = np.ones((K,N), dtype=np.float32)
  else:
    na = np.ones((M,K), dtype=np.float32)
    nb = np.ones((K,N), dtype=np.float32)
  nc = np.zeros(M*N, np.float32)
  if DTYPE_IN != dtypes.float:
    na = na.astype(np.bfloat16 if DTYPE_IN == dtypes.bfloat16 else np.float16)
    nb = nb.astype(np.bfloat16 if DTYPE_IN == dtypes.bfloat16 else np.float16)
  if DTYPE_OUT != dtypes.float:
    nc = nc.astype(np.bfloat16 if DTYPE_OUT == dtypes.bfloat16 else np.float16)
  return na, nb, nc
def ast_to_cuda_prog(compiler, ast, opts):
  k = Kernel(ast)
  k.required_optimizations()
  for opt in opts:
    k.apply_opt(opt)
  p = k.to_program()
  return CUDAProgram(device, p.function_name, compiler.compile(p.src))
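# ast_to_cuda_prog is a helper for turning a tinygrad kernel AST plus a hand-picked list of Opts
# (e.g. the hcopt list quoted below) into a compiled CUDAProgram; note it reads the module-level
# `device` set up in the CUDA branch of __main__, so it can only be called after that branch runs.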
if __name__ == "__main__":
  print(f"gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
  prog, global_size, local_size = None, None, None
  if getenv("CUDA") == 1:
    from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
    device = CUDADevice("cuda:0")
    compiler = CUDACompiler(device.arch)
    cudaalloc = CUDAAllocator(device)
    a = cudaalloc.alloc(M*K*DTYPE_IN.itemsize)
    b = cudaalloc.alloc(K*N*DTYPE_IN.itemsize)
    c = cudaalloc.alloc(M*N*DTYPE_OUT.itemsize)
    if GEMM_VARIATION == "max" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
      print("Using CUDA and max kernel (based on the Triton-generated GEMM)")
      # See nv_triton_gemm.annotated.ptx for the PTX, generated with `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py`
      # this kernel with M=N=K=4096 does 162 TFLOPS, vs torch at 144 TFLOPS and BEAM=8 tinygrad at 138 TFLOPS. theoretical max is 165 TFLOPS.
      # WMMA element size is (M, N, K) = (16, 8, 16)
      # warpgroup size in WMMA tiles is (B_M, B_N, B_K) = (2, 8, 4), so 64 HMMA calls per warpgroup reduce iteration
      # thread block size is (T_M, T_N, T_K) = (2, 2, 1), i.e. macro blocks in M and N, so 256 HMMA calls per kernel reduce iteration
      # kernel reduce iteration size in elements = (64, 128, 64)
      # single iteration SMEM_A = (64 * 64) * (2 bytes / half) = 8192 bytes, SMEM_B = (128 * 64) * (2 bytes / half) = 16384 bytes
      # double-buffered smem = (8192 + 16384) * 2 = 49152 bytes
      # reduce for_loop size = [1, 1, (4096 // 16 // 4) == 64]
      # NOTE: T_K > 0 would be group_for_reduce
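      # sanity arithmetic for the numbers above, per reduce iteration:
      #   HMMAs per warpgroup: B_M*B_N*B_K = 2*8*4 = 64; per thread block: 64 * (T_M*T_N) = 256
      #   SMEM_A = 64*64*2 bytes = 8192, SMEM_B = 128*64*2 bytes = 16384, double-buffered = 49152
      #   reduce trip count for K=4096: 4096 // (B_K*16) = 4096 // 64 = 64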
      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.max.cu')).read()))
      args = (c, a, b)
      kwargs = {
        'global_size': [M//64, N//128, 1],
        'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
        'wait': True,
        'vals': (N, K),
      }
    elif GEMM_VARIATION == "2_stage_swizzled_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
      print("Using CUDA, 2-stage reduce pipeline, swizzled SMEM inputs")
      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu')).read()))
      args = (c, a, b)
      kwargs = {
        'global_size': [M//64, N//128, 1],
        'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
        'wait': True,
        'vals': (N, K),
      }
    elif GEMM_VARIATION == "swizzled_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
      print("Using CUDA, swizzled SMEM inputs")
      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu')).read()))
      args = (c, a, b)
      kwargs = {
        'global_size': [M//64, N//128, 1],
        'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
        'wait': True,
        'vals': (N, K),
      }
    elif GEMM_VARIATION == "flat_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
      print("Using CUDA, flat SMEM inputs")
      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu')).read()))
      args = (c, a, b)
      kwargs = {
        'global_size': [M//64, N//128, 1],
        'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
        'wait': True,
        'vals': (N, K),
      }
elif GEMM_VARIATION == "hcopt" and M == N == K == 4096 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.float:
print("Using CUDA and generated hcopt")
# [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4)]
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp16.hcopt.cu')).read()))
args = (c, a, b)
kwargs = {
'global_size': [32, 64, 1],
'local_size': [16, 2, 4], # 16,2 are warp, 4 workgroups upcasted to axis=1
'wait': True,
}
elif GEMM_VARIATION == "2_stage" and (M%64)== 0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
print("Using CUDA and un-optimized 2-stage, swizzled SMEM inputs and direct acc to output kernel")
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.2_stage.cu')).read()))
args = (c, a, b)
kwargs = {
'global_size': [M//64, N//128, 1],
'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
'wait': True,
'vals': (N, K),
}
elif GEMM_VARIATION == "3_stage" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
print("Using CUDA and 3-stage (interleave global copies and ldmatrix)")
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.3_stage.cu')).read()), 73728)
args = (c, a, b)
kwargs = {
'global_size': [M//256, N//128, 1],
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
'wait': True,
'vals': (N, K),
}
elif GEMM_VARIATION == "3_stage_swizzled" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
print("Using CUDA and 3-stage (interleave global copies and ldmatrix) and swizzled SMEM inputs")
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu')).read()), 73728)
args = (c, a, b)
kwargs = {
'global_size': [M//256, N//128, 1],
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
'wait': True,
'vals': (N, K),
}
elif GEMM_VARIATION == "max" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
print("Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue")
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.max.cu')).read()), 73728)
args = (c, a, b)
kwargs = {
'global_size': [M//256, N//128, 1],
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
'wait': True,
'vals': (N, K),
}
elif GEMM_VARIATION == "no_xor" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
print("Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue")
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.no_xor.cu')).read()), 73728)
args = (c, a, b)
kwargs = {
'global_size': [M//256, N//128, 1],
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
'wait': True,
'vals': (N, K),
}
else:
raise RuntimeError(f"invalid gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
    tms = []
    na, nb, nc = randoms()
    cudaalloc.copyin(a, bytearray(na))
    cudaalloc.copyin(b, bytearray(nb))
    for i in range(CNT):
      tms.append(prog(*args, **kwargs))
    cudaalloc.copyout(flat_mv(nc.data), c)
    comp = na.astype(np.float32) @ nb.astype(np.float32)
    result = nc.reshape(M, N).astype(np.float32)
    print(f"{N*N:10d} {min(tms)*1e6:9.2f} us, would be {FLOPS*1e-9/min(tms):9.2f} GFLOPS matmul, {BW*1e-9/min(tms):.2f} GB/s")
    try:
      np.testing.assert_allclose(result, comp, atol=ATOL, rtol=RTOL)
    except AssertionError as e:
      if getenv("DEBUG_VALUES") > 0:
        indices = np.where(~np.isclose(result, comp, rtol=RTOL, atol=ATOL))
        non_matching_elements_result = result[indices]
        non_matching_elements_comp = comp[indices]
        print("valid :", np.where(np.isclose(result, comp, rtol=RTOL, atol=ATOL)))
        print("invalid :", indices)
        print("result :", non_matching_elements_result)
        print("ground truth:", non_matching_elements_comp)
        print("result sum :", np.sum(result))
        print("ground sum :", np.sum(comp))
      raise e
    if getenv("DEBUG_VALUES") > 0:
      print(comp)
      print("ground sum :", np.sum(comp))
      print(result)
      print("result sum :", np.sum(result))
elif getenv("AMD") == 1:
# note: https://hipfft.readthedocs.io/en/rocm-6.1.2/how-to/fine-tuning-llms/optimizing-triton-kernel.html
# also this is different than the rocblas/tensile approach to GEMM
# see: https://github.com/ROCm/Tensile/blob/develop/Tensile/KernelWriterAssembly.py
raise RuntimeError("invalid max_matmul device")
else:
raise RuntimeError("invalid max_matmul device")