extra/gemm/max_matmul: start of custom kernels for GEMM (#6926)
* extra/gemm/max_matmul: start of custom kernels for GEMM
* add an unoptimized FP16/FP16 MMA example
* add slow 3-stage fp16 acc example
* add correct 3-stage pipeline with unswizzled/flat smem input (slow)
* add acc fp16 example with 3 stages and swizzle (no bank conflicts)
* add max version of NV fp16_fp16_fp16
* fix up comments and removed unused code in max variations
* add start of no_xor example
* fix to account for UOps to Ops
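
Neither kernel ships with a host driver in this commit, so here is a minimal, hypothetical smoke-test harness for the first kernel (nv.fp16_fp16_fp16.2_stage.cu below). The grid/block shape follows the comments in that kernel (one 128-thread block per 64x128 tile of C, BLOCK_K=64); the matrix sizes, the harness itself, and the build line (e.g. nvcc -arch=sm_80 harness.cu nv.fp16_fp16_fp16.2_stage.cu) are assumptions, not part of the commit.

// hypothetical harness for nv.fp16_fp16_fp16.2_stage.cu (assumed sizes, not in the commit)
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

// kernel defined in nv.fp16_fp16_fp16.2_stage.cu: data0 = C (MxN), data1 = A (MxK), data2 = B (KxN), row-major fp16
extern "C" __global__ void wmma_example(half* data0, const half* data1, const half* data2, int N, int K);

int main() {
  const int M = 4096, N = 4096, K = 4096;  // assumed multiples of the 64x128x64 tile
  half *A, *B, *C;
  cudaMalloc(&A, (size_t)M * K * sizeof(half));
  cudaMalloc(&B, (size_t)K * N * sizeof(half));
  cudaMalloc(&C, (size_t)M * N * sizeof(half));
  dim3 grid(M / 64, N / 128), block(128);  // one block per 64x128 output tile, 4 warps
  wmma_example<<<grid, block>>>(C, A, B, N, K);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaDeviceSynchronize();
  cudaFree(A); cudaFree(B); cudaFree(C);
  return 0;
}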
extra/gemm/max_kernels/nv.fp16_fp16_fp16.2_stage.cu (new file, 508 lines added)
@@ -0,0 +1,508 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132

struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }

struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }

__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
  uint32_t reg0, reg1, reg2, reg3;
  asm volatile(
    "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
    : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
    : "l"(__cvta_generic_to_shared(smem))
  );
  uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
  addr[0] = reg0;
  addr[1] = reg1;
  addr[2] = reg2;
  addr[3] = reg3;
}

__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
  uint32_t reg0, reg1, reg2, reg3;
  asm volatile(
    "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
    : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
    : "l"(__cvta_generic_to_shared(smem))
  );
  uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
  uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
  addr_lo[0] = reg0;
  addr_lo[1] = reg1;
  addr_hi[0] = reg2;
  addr_hi[1] = reg3;
}

__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
  int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
  asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
    : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
  return c;
}

extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
  int grid_m = blockIdx.x; /* M//64 */
  int grid_n = blockIdx.y; /* N//128 */
  int threads = threadIdx.x; /* 128 */
  int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
  int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
  int wg_threads = threads%32;
  int num_k_blocks = K / 64;

  // load indexes
  size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
  size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);

  // swizzled smem store offsets - columns of smem are swizzled
  // here's a link to a description of the triton: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
  // see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
  size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
  size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19

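  // note on the swizzle math above: each cp.async moves a 16-byte (8-half) chunk, and the chunk
  // index within a row is XORed with the row index so the chunks of neighboring rows land in
  // different banks (this is what removes the bank conflicts on the ldmatrix side):
  //   store_smem_a_off = row*64  + (((threads % 8)  ^ (row % 8)) * 8)  with row = threads / 8
  //   store_smem_b_off = row*128 + (((threads % 16) ^  row     ) * 8)  with row = threads / 16
  // e.g. thread 9 stores its A chunk at row 1, chunk 1^1 = 0, i.e. half offset 64.
  // the swizzled ldmatrix offsets below undo this by XORing the wanted chunk with (threads % 8).
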
  // ldmatrix indices
  // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
  // [ A | C ]
  // [ - + - ]
  // [ B | D ]

  // swizzled ldmatrix
  size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
  size_t load_smem_a_phase = (threads / 16) % 2; // r4

  size_t load_smem_b_row = (threads % 16) * 128; // r299
  size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)

  size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
  size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
  size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
  size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
  size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
  size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);

  size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
  size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
  size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
  size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
  size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
  size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);

  size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
  size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
  size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
  size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
  size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
  size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);

  size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
  size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
  size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
  size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
  size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
  size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);

  // create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B2_16384 bytes)
  __shared__ alignas(16) char smem[49152];

  // create accs (16 WMMAs and 4 output elements each) and zero
  half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);

  // create registers for block A elements (2)
  half8 a_frag_0;
  half8 a_frag_1;

  // create register for block B elements (8)
  half4 b_frag_0;
  half4 b_frag_1;
  half4 b_frag_2;
  half4 b_frag_3;
  half4 b_frag_4;
  half4 b_frag_5;
  half4 b_frag_6;
  half4 b_frag_7;

  half *smem_a_even = (half *)(smem);
  half *smem_a_odd = (half *)(smem + 8192);
  half *smem_b_even = (half *)(smem + 16384);
  half *smem_b_odd = (half *)(smem + 32768);

  // https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies

  // start first pre-fetch load A
  __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
  __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
  __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);

  // start first pre-fetch load B
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
  __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
  __pipeline_commit();

  global_a_off += 64;
  global_b_off += 64 * N;
  __syncthreads();

  // start second pre-fetch load A
  __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
  __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
  __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);

  // start second pre-fetch load B
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
  __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
  __pipeline_commit();

  global_a_off += 64;
  global_b_off += 64 * N;

  // wait on needed prefetch value
  __pipeline_wait_prior(0); // TODO: this enables fast iterations, but incorrect results with 1 (it shouldn't)
  __syncthreads();

  for (int block_k = 0; block_k < num_k_blocks; block_k++) {
    // BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
    half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd;
    half *smem_b_curr = (block_k % 2) ? smem_b_even : smem_b_odd;

    // first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);

    // next 16 K elements
    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);

    // next 16 K elements
    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);

    // last 16 K elements
    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);

    // prefetch next iteration if needed
    __syncthreads();
    if (block_k < (num_k_blocks-2)) {
      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);

      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);

      global_a_off += 64;
      global_b_off += 64 * N;
    }
    __pipeline_commit();

    if (block_k < num_k_blocks-1) {
      __pipeline_wait_prior(1);
      __syncthreads();
    }
  }

  // write accumulators to output
  __pipeline_wait_prior(0);
  __syncthreads();

  // // store registers to smem first, then read back to do float4 writes to global
  // float *smem_d = (float *)(smem);
  // size_t smem_d_off = (wg_m * 16 * N_PAD) + (wg_n * 16) + ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N_PAD);
  // smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_0_0.x;
  // smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_0_0.y;
  // smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.z;
  // smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.w;
  // smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_0_1.x;
  // smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_0_1.y;
  // smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.z;
  // smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.w;
  // smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_0_2.x;
  // smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_0_2.y;
  // smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.z;
  // smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.w;
  // smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_0_3.x;
  // smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_0_3.y;
  // smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.z;
  // smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.w;
  // smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_0_4.x;
  // smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_0_4.y;
  // smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.z;
  // smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.w;
  // smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_0_5.x;
  // smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_0_5.y;
  // smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.z;
  // smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.w;
  // smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_0_6.x;
  // smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_0_6.y;
  // smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_0_6.z;
  // smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_0_6.w;
  // smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_0_7.x;
  // smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_0_7.y;
  // smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_0_7.z;
  // smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_0_7.w;

  // __syncthreads();
  // size_t load_smem_d_off = ((threads % 32) * 4) + ((threads / 32) * N_PAD);
  // float4 d_0_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
  // float4 d_0_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
  // float4 d_0_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
  // float4 d_0_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
  // float4 d_0_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
  // float4 d_0_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
  // float4 d_0_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
  // float4 d_0_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));

  // __syncthreads();
  // smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_1_0.x;
  // smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_1_0.y;
  // smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.z;
  // smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.w;
  // smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_1_1.x;
  // smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_1_1.y;
  // smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.z;
  // smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.w;
  // smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_1_2.x;
  // smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_1_2.y;
  // smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.z;
  // smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.w;
  // smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_1_3.x;
  // smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_1_3.y;
  // smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.z;
  // smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.w;
  // smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_1_4.x;
  // smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_1_4.y;
  // smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.z;
  // smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.w;
  // smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_1_5.x;
  // smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_1_5.y;
  // smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.z;
  // smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.w;
  // smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_1_6.x;
  // smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_1_6.y;
  // smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_1_6.z;
  // smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_1_6.w;
  // smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_1_7.x;
  // smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_1_7.y;
  // smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_1_7.z;
  // smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_1_7.w;

  // __syncthreads();
  // float4 d_1_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
  // float4 d_1_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
  // float4 d_1_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
  // float4 d_1_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
  // float4 d_1_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
  // float4 d_1_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
  // float4 d_1_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
  // float4 d_1_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));

  // __syncthreads();
  // float *global_d = &data0[((grid_m * 64) * N) + (grid_n * 128) + ((threads % 32) * 4) + ((threads / 32) * N)];
  // *((float4 *)(global_d + 0*N)) = d_0_0;
  // *((float4 *)(global_d + 4*N)) = d_0_1;
  // *((float4 *)(global_d + 8*N)) = d_0_2;
  // *((float4 *)(global_d + 12*N)) = d_0_3;
  // *((float4 *)(global_d + 16*N)) = d_0_4;
  // *((float4 *)(global_d + 20*N)) = d_0_5;
  // *((float4 *)(global_d + 24*N)) = d_0_6;
  // *((float4 *)(global_d + 28*N)) = d_0_7;
  // *((float4 *)(global_d + 32*N)) = d_1_0;
  // *((float4 *)(global_d + 36*N)) = d_1_1;
  // *((float4 *)(global_d + 40*N)) = d_1_2;
  // *((float4 *)(global_d + 44*N)) = d_1_3;
  // *((float4 *)(global_d + 48*N)) = d_1_4;
  // *((float4 *)(global_d + 52*N)) = d_1_5;
  // *((float4 *)(global_d + 56*N)) = d_1_6;
  // *((float4 *)(global_d + 60*N)) = d_1_7;

  // slower way: write floats one by one to data0
  size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
  size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
  data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
  data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
  data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
  data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
  data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
  data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
  data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
  data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
  data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
  data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
  data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
  data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
  data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
  data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
  data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
  data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
  wg_c_off += 32*N;
  data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
  data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
  data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
  data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
  data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
  data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
  data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
  data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
  data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
  data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
  data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
  data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
  data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
  data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
  data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
  data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
}
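The 3-stage kernel in the next file keeps three A buffers (256x32 halfs, 16384 bytes each) and three B buffers (32x128 halfs, 8192 bytes each) in extern __shared__ memory, so it needs 3*16384 + 3*8192 = 73728 bytes of dynamic shared memory and a (32,4,2) thread block. A hypothetical launch under those assumptions follows; the opt-in attribute call, sizes, and harness are not part of the commit.

// hypothetical harness for nv.fp16_fp16_fp16.3_stage.cu (assumed sizes, not in the commit)
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

extern "C" __global__ void wmma_example(half* data0, const half* data1, const half* data2, int N, int K);

int main() {
  const int M = 4096, N = 4096, K = 4096;       // assumed multiples of the 256x128x32 tile
  const int SMEM_BYTES = 3 * 16384 + 3 * 8192;  // 73728 bytes: 3 A tiles + 3 B tiles
  half *A, *B, *C;
  cudaMalloc(&A, (size_t)M * K * sizeof(half));
  cudaMalloc(&B, (size_t)K * N * sizeof(half));
  cudaMalloc(&C, (size_t)M * N * sizeof(half));
  // more than 48KB of dynamic shared memory requires an explicit opt-in
  cudaFuncSetAttribute(wmma_example, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_BYTES);
  dim3 grid(M / 256, N / 128), block(32, 4, 2);  // one block per 256x128 output tile, 8 warps
  wmma_example<<<grid, block, SMEM_BYTES>>>(C, A, B, N, K);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaDeviceSynchronize();
  cudaFree(A); cudaFree(B); cudaFree(C);
  return 0;
}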
extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage.cu (new file, 465 lines added)
@@ -0,0 +1,465 @@
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define N_PAD 132

struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }

struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }

__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
  uint32_t reg0, reg1, reg2, reg3;
  asm volatile(
    "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
    : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
    : "l"(__cvta_generic_to_shared(smem))
  );
  uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
  addr[0] = reg0;
  addr[1] = reg1;
  addr[2] = reg2;
  addr[3] = reg3;
}

__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
  uint32_t reg0, reg1, reg2, reg3;
  asm volatile(
    "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
    : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
    : "l"(__cvta_generic_to_shared(smem))
  );
  uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
  uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
  addr_lo[0] = reg0;
  addr_lo[1] = reg1;
  addr_hi[0] = reg2;
  addr_hi[1] = reg3;
}

__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
  int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
  asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
    : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
  return c;
}

extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
  extern __shared__ char smem[];
  half *smem_a_0 = (half *)(smem);
  half *smem_a_1 = (half *)(smem + 16384);
  half *smem_a_2 = (half *)(smem + 32768);
  half *smem_b_0 = (half *)(smem + 49152);
  half *smem_b_1 = (half *)(smem + 57344);
  half *smem_b_2 = (half *)(smem + 65536);

  int grid_m = blockIdx.x; /* M//256 */
  int grid_n = blockIdx.y; /* N//128 */
  int wg_threads = threadIdx.x; // 32
  int wg_m = threadIdx.y; // 4
  int wg_n = threadIdx.z; // 2
  int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
  int num_k_blocks = K / 32;

  // load indexes
  size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + ((threads / 4) * K);
  size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);

  // unswizzled smem store
  size_t store_smem_a_off = ((threads % 4) * 8) + ((threads / 4) * 32); // 64 rows / 32 cols per copy
  size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); // 16 rows / 128 cols per copy

  // ldmatrix indices
  // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
  // [ A | C ]
  // [ - + - ]
  // [ B | D ]

  // unswizzled ldmatrix
  size_t load_smem_a_0_k_0 = (wg_m * 16 * 32) + ((wg_threads % 16) * 32) + ((wg_threads / 16) * 8);
  size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + ( 64 * 32);
  size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + (128 * 32);
  size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (192 * 32);

  size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
  size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + ( 64 * 32);
  size_t load_smem_a_2_k_1 = load_smem_a_0_k_1 + (128 * 32);
  size_t load_smem_a_3_k_1 = load_smem_a_0_k_1 + (192 * 32);

  size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 16) * 128) + ((wg_threads / 16) * 8);
  size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32;
  size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64;
  size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96;

  size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
  size_t load_smem_b_1_k_1 = load_smem_b_0_k_1 + 32;
  size_t load_smem_b_2_k_1 = load_smem_b_0_k_1 + 64;
  size_t load_smem_b_3_k_1 = load_smem_b_0_k_1 + 96;

  // create accs (M=4, N=8)
  half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
  half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);

  // create registers for block A elements
  half8 a_frag_0_k_0;
  half8 a_frag_1_k_0;
  half8 a_frag_2_k_0;
  half8 a_frag_3_k_0;
  half8 a_frag_0_k_1;
  half8 a_frag_1_k_1;
  half8 a_frag_2_k_1;
  half8 a_frag_3_k_1;

  // create register for block B elements
  half4 b_frag_0_k_0;
  half4 b_frag_1_k_0;
  half4 b_frag_2_k_0;
  half4 b_frag_3_k_0;
  half4 b_frag_4_k_0;
  half4 b_frag_5_k_0;
  half4 b_frag_6_k_0;
  half4 b_frag_7_k_0;
  half4 b_frag_0_k_1;
  half4 b_frag_1_k_1;
  half4 b_frag_2_k_1;
  half4 b_frag_3_k_1;
  half4 b_frag_4_k_1;
  half4 b_frag_5_k_1;
  half4 b_frag_6_k_1;
  half4 b_frag_7_k_1;

  __syncthreads();

  // load first tile
  __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
  __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
  __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
  __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
  __pipeline_commit();

  global_a_off += 32;
  global_b_off += 32 * N;

  // load second tile
  __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
  __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
  __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
  __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
  __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
  __pipeline_commit();

  global_a_off += 32;
  global_b_off += 32 * N;

  // wait on first pre-fetch load
  __pipeline_wait_prior(1);
  __syncthreads();

  // load K=0 for the first tile
  __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
  __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
  __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
  __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
  __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
  __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
  __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
  __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);

  for (int block_k = 0; block_k < num_k_blocks; block_k++) {
    int phase_k = block_k % 3;
    half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
    half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);

    int next_phase_k = (block_k+1) % 3;
    half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
    half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);

    int store_phase_k = (block_k+2) % 3;
    half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
    half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);

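    // buffer rotation: this iteration consumes buffer block_k%3 (already resident), buffer
    // (block_k+1)%3 holds the copy committed on the previous iteration (what the
    // __pipeline_wait_prior(1) below waits for), and buffer (block_k+2)%3 receives the
    // cp.async copies issued this iteration.
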
// load K=1 elements for the current tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
|
||||
// MMA K=0, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
|
||||
|
||||
// load next tile
|
||||
if (block_k < (num_k_blocks-2)) {
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
}
|
||||
__pipeline_commit();
|
||||
|
||||
// wait next tile
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
|
||||
// load K=0 for the next tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
|
||||
|
||||
// MMA K=1, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// slower way: write accs one by one to data0
|
||||
|
||||
size_t wg_c_off = ((grid_m * 256) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
|
||||
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
|
||||
|
||||
wg_c_off += 64*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
wg_c_off += 64*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_2_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_2_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_2_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_2_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_2_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_2_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_2_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_2_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_2_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_2_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_2_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_2_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_2_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_2_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_2_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_2_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_2_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_2_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_2_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_2_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_2_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_2_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_2_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_2_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_2_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_2_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_2_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_2_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_2_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_2_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_2_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_2_7.w;
wg_c_off += 64*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_3_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_3_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_3_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_3_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_3_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_3_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_3_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_3_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_3_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_3_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_3_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_3_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_3_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_3_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_3_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_3_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_3_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_3_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_3_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_3_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_3_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_3_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_3_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_3_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_3_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_3_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_3_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_3_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_3_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_3_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_3_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_3_7.w;

}
517
extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu
Normal file
@@ -0,0 +1,517 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define N_PAD 132
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
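
// __ldmatrix_b_elems below uses the .trans variant: each 8x8 tile is transposed on the way out
// of shared memory, turning the row-major (N-contiguous) B tiles into the column-major fragments
// that mma.sync.aligned.m16n8k16.row.col expects for its B operand.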
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
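
// operand packing for mma.m16n8k16 (f16 accumulate): per lane, A is 4 x b32 registers (a half8),
// B is 2 x b32 (a half4) and the C/D accumulator is 2 x b32 (a half4); the int* casts below just
// reinterpret those structs as packed 32-bit registers for the inline asm.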
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
|
||||
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
|
||||
extern __shared__ char smem[];
|
||||
half *smem_a_0 = (half *)(smem);
|
||||
half *smem_a_1 = (half *)(smem + 16384);
|
||||
half *smem_a_2 = (half *)(smem + 32768);
|
||||
half *smem_b_0 = (half *)(smem + 49152);
|
||||
half *smem_b_1 = (half *)(smem + 57344);
|
||||
half *smem_b_2 = (half *)(smem + 65536);
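// shared memory plan: three A stages of 128 rows x 64 halves (16 KiB each) followed by three B
// stages of 32 rows x 128 halves (8 KiB each), 73728 bytes in total, so the host has to opt in
// to >48 KiB of dynamic smem per block. a plausible launch (assumed, not part of this file):
//   cudaFuncSetAttribute(wmma_example, cudaFuncAttributeMaxDynamicSharedMemorySize, 73728);
//   wmma_example<<<dim3(M/256, N/128), dim3(32, 4, 2), 73728>>>(data0, data1, data2, N, K);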
int grid_m = blockIdx.x; /* M//256 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int wg_threads = threadIdx.x; // 32
|
||||
int wg_m = threadIdx.y; // 4
|
||||
int wg_n = threadIdx.z; // 2
|
||||
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
|
||||
int num_k_blocks = K / 32;
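// no tail handling anywhere below: this kernel assumes M % 256 == 0, N % 128 == 0, K % 32 == 0,
// and (because two tiles are prefetched before the main loop) at least K >= 64.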
// ldmatrix indices
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// unswizzled A - SMEM_A is 256 rows x 32 cols
|
||||
// size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + ((threads / 4) * K);
|
||||
// size_t store_smem_a_off = ((threads % 4) * 8) + ((threads / 4) * 32); // 64 rows / 32 cols per copy
|
||||
// size_t load_smem_a_0_k_0 = (wg_m * 16 * 32) + ((wg_threads % 16) * 32) + ((wg_threads / 16) * 8);
|
||||
// size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + ( 64 * 32);
|
||||
// size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + (128 * 32);
|
||||
// size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (192 * 32);
|
||||
// size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
|
||||
// size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + ( 64 * 32);
|
||||
// size_t load_smem_a_2_k_1 = load_smem_a_0_k_1 + (128 * 32);
|
||||
// size_t load_smem_a_3_k_1 = load_smem_a_0_k_1 + (192 * 32);
|
||||
|
||||
// unswizzled reshaped A - SMEM_A is 128 rows x 64 cols, [ (M=0, K=0), (M=0, K=1), (M=8, K=0), (M=8, K=1) ], etc.
|
||||
// size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
|
||||
// size_t store_smem_a_off = ((threads % 8) * 8) + ((threads / 8) * 64); // 32 rows / 64 cols per copy
|
||||
// size_t load_smem_a_0_k_0 = (wg_m * 16 * 64) + ((wg_threads % 16) * 64) + ((wg_threads / 16) * 8);
|
||||
// size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (64 * 64);
|
||||
// size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + 32;
|
||||
// size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (64 * 64) + 32;
|
||||
// size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
|
||||
// size_t load_smem_a_1_k_1 = load_smem_a_1_k_0 + 16;
|
||||
// size_t load_smem_a_2_k_1 = load_smem_a_2_k_0 + 16;
|
||||
// size_t load_smem_a_3_k_1 = load_smem_a_3_k_0 + 16;
|
||||
|
||||
// swizzled A
|
||||
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
|
||||
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
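// ((threads * 8) ^ threads) & 56 == ((threads % 8) ^ ((threads / 8) % 8)) * 8: each 64-half row
// is eight 16-byte chunks, and the chunk a thread brings in is stored at column (chunk ^ (row % 8)).
// e.g. thread 10 fills row 1 with chunk 2, which lands at half offset 64*1 + 8*(2 ^ 1) = 88.
// the XOR spreads the eight chunks a warp later reads with ldmatrix across distinct banks.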
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
|
||||
size_t load_smem_a_phase = (threads / 16) % 2;
|
||||
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
|
||||
|
||||
// unswizzled B
|
||||
// size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
// size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); // 16 rows / 128 cols per copy
|
||||
// size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 16) * 128) + ((wg_threads / 16) * 8);
|
||||
// size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32;
|
||||
// size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64;
|
||||
// size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96;
|
||||
// size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
// size_t load_smem_b_1_k_1 = load_smem_b_0_k_1 + 32;
|
||||
// size_t load_smem_b_2_k_1 = load_smem_b_0_k_1 + 64;
|
||||
// size_t load_smem_b_3_k_1 = load_smem_b_0_k_1 + 96;
|
||||
|
||||
// swizzled B
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy
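// same XOR swizzle for B: each 128-half row is sixteen 8-half chunks, and logical chunk i of
// row r is stored at chunk (i ^ (r % 8)); the load offsets below undo it by XORing with
// (threads % 8), which equals (row % 8) for the row each lane points ldmatrix at.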
size_t load_smem_b_row = (threads % 16) * 128;
|
||||
size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16);
|
||||
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
|
||||
|
||||
// create accs (M=4, N=8)
|
||||
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
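// each warp accumulates a 64x64 slice of the 256x128 block tile: 4 MMA tiles tall (m16) by
// 8 wide (n8), spread over interleaved 16-row / 16-column stripes (see the epilogue offsets).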
// create registers for block A elements
|
||||
half8 a_frag_0_k_0;
|
||||
half8 a_frag_1_k_0;
|
||||
half8 a_frag_2_k_0;
|
||||
half8 a_frag_3_k_0;
|
||||
half8 a_frag_0_k_1;
|
||||
half8 a_frag_1_k_1;
|
||||
half8 a_frag_2_k_1;
|
||||
half8 a_frag_3_k_1;
|
||||
|
||||
// create register for block B elements
|
||||
half4 b_frag_0_k_0;
|
||||
half4 b_frag_1_k_0;
|
||||
half4 b_frag_2_k_0;
|
||||
half4 b_frag_3_k_0;
|
||||
half4 b_frag_4_k_0;
|
||||
half4 b_frag_5_k_0;
|
||||
half4 b_frag_6_k_0;
|
||||
half4 b_frag_7_k_0;
|
||||
half4 b_frag_0_k_1;
|
||||
half4 b_frag_1_k_1;
|
||||
half4 b_frag_2_k_1;
|
||||
half4 b_frag_3_k_1;
|
||||
half4 b_frag_4_k_1;
|
||||
half4 b_frag_5_k_1;
|
||||
half4 b_frag_6_k_1;
|
||||
half4 b_frag_7_k_1;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// load first tile
|
||||
// unswizzled 256 x 32
|
||||
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
|
||||
// unswizzled 128 x 64
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
__pipeline_commit();
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
|
||||
// load second tile
|
||||
// unswizzled 256 x 32
|
||||
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
|
||||
// unswizzled 128 x 64
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
|
||||
// wait on first pre-fetch load
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
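// two cp.async batches are in flight here; wait_prior(1) above blocks until all but the newest
// commit has landed, i.e. until tile 0 is resident in smem_a_0/smem_b_0 while tile 1 keeps streaming.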
// load K=0 for the first tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
|
||||
int phase_k = block_k % 3;
|
||||
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
int next_phase_k = (block_k+1) % 3;
|
||||
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
int store_phase_k = (block_k+2) % 3;
|
||||
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
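// 3-stage ring buffer: iteration k runs MMAs out of stage k%3, ldmatrix-loads registers for the
// next iteration from stage (k+1)%3 (already resident after the wait below), and cp.async-fills
// stage (k+2)%3 two iterations ahead.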
// load K=1 elements for the current tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
|
||||
// MMA K=0, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
|
||||
|
||||
// load next tile
|
||||
if (block_k < (num_k_blocks-2)) {
|
||||
// unswizzled 256 x 32
|
||||
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16);
|
||||
// __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16);
|
||||
// unswizzled 128 x 64
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
}
|
||||
__pipeline_commit();
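// the commit is issued even when nothing was enqueued (the last two iterations) so the number of
// outstanding batches stays consistent and the wait_prior(1) below keeps matching the right tile.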
// wait next tile
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
|
||||
// load K=0 for the next tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
|
||||
|
||||
// MMA K=1, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// slower way: write accs one by one to data0
|
||||
|
||||
size_t wg_c_off = ((grid_m * 256) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
|
||||
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
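// m16n8k16 f16 accumulator layout: lane l holds {x,y,z,w} for rows {l/4, l/4 + 8} and columns
// {2*(l%4), 2*(l%4)+1} of its 16x8 tile, which is where the +1 and +(8*N) offsets come from.
// the (j*8) column offsets run over j in {0,1,4,5,8,9,12,13}: the two N-warps interleave
// 16-column groups, so each warp covers columns wg_n*16 + {0..15, 32..47, 64..79, 96..111}.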
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
wg_c_off += 64*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
wg_c_off += 64*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_2_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_2_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_2_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_2_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_2_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_2_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_2_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_2_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_2_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_2_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_2_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_2_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_2_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_2_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_2_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_2_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_2_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_2_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_2_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_2_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_2_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_2_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_2_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_2_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_2_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_2_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_2_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_2_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_2_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_2_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_2_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_2_7.w;
wg_c_off += 64*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_3_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_3_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_3_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_3_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_3_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_3_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_3_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_3_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_3_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_3_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_3_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_3_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_3_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_3_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_3_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_3_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_3_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_3_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_3_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_3_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_3_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_3_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_3_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_3_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_3_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_3_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_3_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_3_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_3_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_3_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_3_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_3_7.w;

}
482
extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu
Normal file
@@ -0,0 +1,482 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define SMEM_N_WIDTH 136
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
|
||||
|
||||
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
|
||||
|
||||
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
|
||||
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
|
||||
extern __shared__ char smem[];
|
||||
half *smem_a_0 = (half *)(smem);
|
||||
half *smem_a_1 = (half *)(smem + 16384);
|
||||
half *smem_a_2 = (half *)(smem + 32768);
|
||||
half *smem_b_0 = (half *)(smem + 49152);
|
||||
half *smem_b_1 = (half *)(smem + 57344);
|
||||
half *smem_b_2 = (half *)(smem + 65536);
|
||||
|
||||
int grid_m = blockIdx.x; /* M//256 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int wg_threads = threadIdx.x; // 32
|
||||
int wg_m = threadIdx.y; // 4
|
||||
int wg_n = threadIdx.z; // 2
|
||||
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
|
||||
int num_k_blocks = K / 32;
|
||||
|
||||
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// swizzled A - SMEM_A is 128 rows x 64 cols
|
||||
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
|
||||
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
|
||||
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
|
||||
size_t load_smem_a_phase = (threads / 16) % 2;
|
||||
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
|
||||
|
||||
// swizzled B - SMEM_B is 32 rows x 128 cols
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy
|
||||
size_t load_smem_b_row = (threads % 16) * 128;
|
||||
size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16);
|
||||
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
|
||||
|
||||
// create accs (M=4, N=8)
|
||||
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
|
||||
// create registers for block A elements
|
||||
half8 a_frag_0_k_0;
|
||||
half8 a_frag_1_k_0;
|
||||
half8 a_frag_2_k_0;
|
||||
half8 a_frag_3_k_0;
|
||||
half8 a_frag_0_k_1;
|
||||
half8 a_frag_1_k_1;
|
||||
half8 a_frag_2_k_1;
|
||||
half8 a_frag_3_k_1;
|
||||
|
||||
// create register for block B elements
|
||||
half4 b_frag_0_k_0;
|
||||
half4 b_frag_1_k_0;
|
||||
half4 b_frag_2_k_0;
|
||||
half4 b_frag_3_k_0;
|
||||
half4 b_frag_4_k_0;
|
||||
half4 b_frag_5_k_0;
|
||||
half4 b_frag_6_k_0;
|
||||
half4 b_frag_7_k_0;
|
||||
half4 b_frag_0_k_1;
|
||||
half4 b_frag_1_k_1;
|
||||
half4 b_frag_2_k_1;
|
||||
half4 b_frag_3_k_1;
|
||||
half4 b_frag_4_k_1;
|
||||
half4 b_frag_5_k_1;
|
||||
half4 b_frag_6_k_1;
|
||||
half4 b_frag_7_k_1;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// load first tile
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
__pipeline_commit();
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
|
||||
// load second tile
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
|
||||
// wait on first pre-fetch load
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
|
||||
// load K=0 elements for the first tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
|
||||
int phase_k = block_k % 3;
|
||||
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
int next_phase_k = (block_k+1) % 3;
|
||||
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
int store_phase_k = (block_k+2) % 3;
|
||||
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
// load K=1 elements for the current tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
|
||||
// MMA K=0, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
|
||||
|
||||
// load next tile if needed
|
||||
if (block_k < (num_k_blocks-2)) {
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
}
|
||||
__pipeline_commit();
|
||||
|
||||
// wait next tile
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
|
||||
// load K=0 elements for the next tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
|
||||
|
||||
// MMA K=1, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// faster epilogue: write each 8x8 TC acc to SMEM first
// - SMEM_N_WIDTH is 8 larger than 128, required to avoid shared-memory bank conflicts
// - around 14 microseconds
// - check bank conflicts under sudo with: "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py"

// epilogue chunk with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (hi/lo for each in TC M)
// 1) write 32 rows of 128 cols (rows 0-7, 16-23, 32-39, 48-55 in acc_frag_0.lo, then acc_frag_0.hi, etc.)
// 2) read/write 16 rows of 128 elements in 8-element (16B) chunks
|
||||
|
||||
half2 *smem32_d = (half2 *)(smem);
|
||||
half8 *smem128_d = (half8 *)(smem);
|
||||
half8 *out128_d = (half8 *)(data0);
|
||||
size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2));
|
||||
size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4);
|
||||
size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16);
|
||||
size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) +
|
||||
((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16);
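// out128_d_off is in 16-byte (8 half) chunks: (grid_m*256)*(N/8) selects this block's 256-row slab
// and grid_n*(128/8) its 128-column slab; within that, threads/128 adds a 16-row offset,
// (threads/16)%8 picks the row inside an 8-row group, and threads%16 the chunk across the 128 columns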
|
||||
|
||||
// write acc_frag_0_*
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write acc_frag_1_*
|
||||
out128_d_off += (64 * (N / 8));
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write acc_frag_2_*
|
||||
out128_d_off += (64 * (N / 8));
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write acc_frag_3_*
|
||||
out128_d_off += (64 * (N / 8));
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.z, acc_frag_3_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
__syncthreads();
|
||||
}
486
extra/gemm/max_kernels/nv.fp16_fp16_fp16.no_xor.cu
Normal file
@@ -0,0 +1,486 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define SMEM_N_WIDTH 136
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
|
||||
|
||||
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
|
||||
|
||||
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
|
||||
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
|
||||
extern __shared__ char smem[];
|
||||
half *smem_a_0 = (half *)(smem);
|
||||
half *smem_a_1 = (half *)(smem + 16384);
|
||||
half *smem_a_2 = (half *)(smem + 32768);
|
||||
half *smem_b_0 = (half *)(smem + 49152);
|
||||
half *smem_b_1 = (half *)(smem + 57344);
|
||||
half *smem_b_2 = (half *)(smem + 65536);
|
||||
|
||||
int grid_m = blockIdx.x; /* M//256 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int wg_threads = threadIdx.x; // 32
|
||||
int wg_m = threadIdx.y; // 4
|
||||
int wg_n = threadIdx.z; // 2
|
||||
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */
|
||||
int num_k_blocks = K / 32;
|
||||
|
||||
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// swizzled A - SMEM_A is 128 rows x 64 cols
|
||||
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
|
||||
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
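// the (((threads * 8) ^ threads) & 56) term XORs the 16-byte chunk index (threads & 7) with the
// low 3 bits of the destination row (threads / 8) -- the standard XOR swizzle that keeps the
// ldmatrix loads below free of shared-memory bank conflicts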
|
||||
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
|
||||
size_t load_smem_a_phase = (threads / 16) % 2;
|
||||
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
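// the +0/+2/+4/+6 phase offsets pick 8-half chunks inside the packed 128x64 A tile: +2 selects the
// K=16..31 half of the 32-wide k-block, +4 selects the data from M rows 128 further down that shares
// the same smem row, so a_frag_0..3 end up covering M rows wg_m*16 + {0, 64, 128, 192}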
|
||||
|
||||
// swizzled B - SMEM_B is 64 rows x 64 cols
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
size_t store_smem_b_off = // 32 rows of 64 cols per copy
|
||||
((threads / 128) * (64)) + // [A,C] vs [B,D] in ldmatrix
|
||||
((threads % 2) * (2 * 64)) + // [A vs C] or [B vs. D]
|
||||
(((threads / 2) % 2) * (4 * 64)) + // WG_N in [0, 1]
|
||||
(((threads / 4) % 4) * (8 * 64)) + // B in [0, 1, 2, 3]
|
||||
(((threads / 16) % 8) * (8)); // cols in SMEM_B i.e. rows of 8x8
|
||||
|
||||
size_t load_smem_b_0_k_0 = (wg_n * 4 * 64) + ((wg_threads / 8) * 64) + ((wg_threads % 8) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + ( 8 * 64);
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + (16 * 64);
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + (24 * 64);
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (32 * 64);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (32 * 64);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (32 * 64);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (32 * 64);
|
||||
|
||||
// create accs (M=4, N=8)
|
||||
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
|
||||
|
||||
// create registers for block A elements
|
||||
half8 a_frag_0_k_0;
|
||||
half8 a_frag_1_k_0;
|
||||
half8 a_frag_2_k_0;
|
||||
half8 a_frag_3_k_0;
|
||||
half8 a_frag_0_k_1;
|
||||
half8 a_frag_1_k_1;
|
||||
half8 a_frag_2_k_1;
|
||||
half8 a_frag_3_k_1;
|
||||
|
||||
// create register for block B elements
|
||||
half4 b_frag_0_k_0;
|
||||
half4 b_frag_1_k_0;
|
||||
half4 b_frag_2_k_0;
|
||||
half4 b_frag_3_k_0;
|
||||
half4 b_frag_4_k_0;
|
||||
half4 b_frag_5_k_0;
|
||||
half4 b_frag_6_k_0;
|
||||
half4 b_frag_7_k_0;
|
||||
half4 b_frag_0_k_1;
|
||||
half4 b_frag_1_k_1;
|
||||
half4 b_frag_2_k_1;
|
||||
half4 b_frag_3_k_1;
|
||||
half4 b_frag_4_k_1;
|
||||
half4 b_frag_5_k_1;
|
||||
half4 b_frag_6_k_1;
|
||||
half4 b_frag_7_k_1;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// load first tile
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16);
|
||||
__pipeline_commit();
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
|
||||
// load second tile
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
|
||||
// wait on first pre-fetch load
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
|
||||
// load K=0 elements for the first tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
|
||||
int phase_k = block_k % 3;
|
||||
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
int next_phase_k = (block_k+1) % 3;
|
||||
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
|
||||
|
||||
int store_phase_k = (block_k+2) % 3;
|
||||
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
|
||||
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
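// three-way rotation of the shared-memory buffers: "curr" is the tile consumed by this iteration's
// MMAs, "next" was prefetched earlier and is waited on below, and "store" receives the cp.async
// prefetch for block_k+2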
|
||||
|
||||
// load K=1 elements for the current tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
|
||||
// MMA K=0, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
|
||||
|
||||
// load next tile if needed
|
||||
if (block_k < (num_k_blocks-2)) {
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16);
|
||||
global_a_off += 32;
|
||||
global_b_off += 32 * N;
|
||||
}
|
||||
__pipeline_commit();
|
||||
|
||||
// wait next tile
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
|
||||
// load K=0 elements for the next tile
|
||||
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
|
||||
|
||||
// MMA K=1, (M=4 x N=8)
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
|
||||
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
|
||||
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
|
||||
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
|
||||
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
|
||||
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
|
||||
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
|
||||
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
|
||||
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
|
||||
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
|
||||
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
|
||||
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
|
||||
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
|
||||
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
|
||||
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
|
||||
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
|
||||
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// faster epilogue: write each 8x8 TC acc to SMEM first
// - SMEM_N_WIDTH is 8 larger than 128, required to avoid shared-memory bank conflicts
// - around 14 microseconds
// - check bank conflicts under sudo with: "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py"

// epilogue chunk with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (hi/lo for each in TC M)
// 1) write 32 rows of 128 cols (rows 0-7, 16-23, 32-39, 48-55 in acc_frag_0.lo, then acc_frag_0.hi, etc.)
// 2) read/write 16 rows of 128 elements in 8-element (16B) chunks
|
||||
|
||||
half2 *smem32_d = (half2 *)(smem);
|
||||
half8 *smem128_d = (half8 *)(smem);
|
||||
half8 *out128_d = (half8 *)(data0);
|
||||
size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2));
|
||||
size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4);
|
||||
size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16);
|
||||
size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) +
|
||||
((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16);
|
||||
|
||||
// write acc_frag_0_*
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write acc_frag_1_*
|
||||
out128_d_off += (64 * (N / 8));
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write acc_frag_2_*
|
||||
out128_d_off += (64 * (N / 8));
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write acc_frag_3_*
|
||||
out128_d_off += (64 * (N / 8));
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
// write 32 rows of 128 N elements to SMEM
|
||||
__syncthreads();
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.z, acc_frag_3_2.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w);
|
||||
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w);
|
||||
|
||||
// each thread reads and writes two 8 element chunks
|
||||
__syncthreads();
|
||||
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
|
||||
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
|
||||
|
||||
__syncthreads();
|
||||
}
157
extra/gemm/max_kernels/nv.fp16_fp32_fp16.hcopt.cu
Normal file
@@ -0,0 +1,157 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
struct __align__(8) half4 { half x, y, z, w; }; __device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
|
||||
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;}
|
||||
extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2) {
|
||||
int gidx0 = blockIdx.x; /* 32 */
|
||||
int gidx1 = blockIdx.y; /* 64 */
|
||||
int lidx0 = threadIdx.x; /* 16 */
|
||||
int lidx1 = threadIdx.y; /* 2 */
|
||||
int lidx2 = threadIdx.z; /* 4 */
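// this generated kernel appears to assume a fixed 4096x4096x4096 problem: output rows stride by
// 4096 halfs, the A/B offsets below are multiples of 4096, and the 256-iteration loop covers K in
// steps of 16 (an assumption inferred from the hardcoded offsets, not stated in the source)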
|
||||
float4 cast0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
int alu0 = (gidx0*128);
|
||||
int alu1 = (gidx1*262144);
|
||||
int alu2 = (lidx1*32768);
|
||||
int alu3 = (lidx2*32);
|
||||
int alu4 = (lidx0/8);
|
||||
int alu5 = (alu4*16384);
|
||||
int alu6 = (lidx0%2);
|
||||
int alu7 = (alu6*2);
|
||||
int alu8 = ((lidx0/2)%2);
|
||||
int alu9 = (alu8*4);
|
||||
int alu10 = ((lidx0/4)%2);
|
||||
int alu11 = (alu10*8192);
|
||||
int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3);
|
||||
int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2);
|
||||
float4 acc0 = cast0;
|
||||
float4 acc1 = cast0;
|
||||
float4 acc2 = cast0;
|
||||
float4 acc3 = cast0;
|
||||
float4 acc4 = cast0;
|
||||
float4 acc5 = cast0;
|
||||
float4 acc6 = cast0;
|
||||
float4 acc7 = cast0;
|
||||
float4 acc8 = cast0;
|
||||
float4 acc9 = cast0;
|
||||
float4 acc10 = cast0;
|
||||
float4 acc11 = cast0;
|
||||
float4 acc12 = cast0;
|
||||
float4 acc13 = cast0;
|
||||
float4 acc14 = cast0;
|
||||
float4 acc15 = cast0;
|
||||
for (int ridx0 = 0; ridx0 < 256; ridx0++) {
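// note: every accN is read and written back consistently (e.g. acc0 feeds wmma3, which is assigned
// back to acc0 at the end of the iteration); only the wmmaN temporary numbering is rotated relative
// to the acc numbering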
|
||||
int alu14 = (ridx0*16);
|
||||
int alu15 = (alu13+alu14);
|
||||
int alu16 = (alu14+alu13);
|
||||
int alu17 = (alu0+(alu6*8192)+(alu8*16384)+alu10+(alu4*2)+(lidx1*4)+alu3+(ridx0*65536));
|
||||
half val0 = data2[alu17+8];
|
||||
half val1 = data2[alu17+16];
|
||||
half val2 = data2[alu17+24];
|
||||
half val3 = data2[alu17+4096];
|
||||
half val4 = data2[alu17+4104];
|
||||
half val5 = data2[alu17+4112];
|
||||
half val6 = data2[alu17+4120];
|
||||
half val7 = data2[alu17+32768];
|
||||
half val8 = data2[alu17+32776];
|
||||
half val9 = data2[alu17+32784];
|
||||
half val10 = data2[alu17+32792];
|
||||
half val11 = data2[alu17+36864];
|
||||
half val12 = data2[alu17+36872];
|
||||
half4 cast1 = make_half4(val0,val4,val8,val12);
|
||||
half val13 = data2[alu17+36880];
|
||||
half4 cast2 = make_half4(val1,val5,val9,val13);
|
||||
half val14 = data2[alu17+36888];
|
||||
half4 cast3 = make_half4(val2,val6,val10,val14);
|
||||
half val15 = data2[alu17];
|
||||
half4 cast4 = make_half4(val15,val3,val7,val11);
|
||||
half2 val16 = *((half2*)(data1+alu15+4096));
|
||||
half2 val17 = *((half2*)(data1+alu15+65536));
|
||||
half2 val18 = *((half2*)(data1+alu15+69632));
|
||||
half2 val19 = *((half2*)(data1+alu15+131072));
|
||||
half2 val20 = *((half2*)(data1+alu15+135168));
|
||||
half2 val21 = *((half2*)(data1+alu15+196608));
|
||||
half2 val22 = *((half2*)(data1+alu15+200704));
|
||||
half2 val23 = *((half2*)(data1+alu15));
|
||||
half2 val24 = *((half2*)(data1+alu16+8));
|
||||
half2 val25 = *((half2*)(data1+alu16+4104));
|
||||
half8 cast5 = make_half8(val23.x,val23.y,val16.x,val16.y,val24.x,val24.y,val25.x,val25.y);
|
||||
float4 wmma0 = __WMMA_8_16_16_half_float(cast5, cast1, acc1);
|
||||
float4 wmma1 = __WMMA_8_16_16_half_float(cast5, cast2, acc2);
|
||||
float4 wmma2 = __WMMA_8_16_16_half_float(cast5, cast3, acc3);
|
||||
float4 wmma3 = __WMMA_8_16_16_half_float(cast5, cast4, acc0);
|
||||
half2 val26 = *((half2*)(data1+alu16+65544));
|
||||
half2 val27 = *((half2*)(data1+alu16+69640));
|
||||
half8 cast6 = make_half8(val17.x,val17.y,val18.x,val18.y,val26.x,val26.y,val27.x,val27.y);
|
||||
float4 wmma4 = __WMMA_8_16_16_half_float(cast6, cast1, acc5);
|
||||
float4 wmma5 = __WMMA_8_16_16_half_float(cast6, cast2, acc6);
|
||||
float4 wmma6 = __WMMA_8_16_16_half_float(cast6, cast3, acc7);
|
||||
float4 wmma7 = __WMMA_8_16_16_half_float(cast6, cast4, acc4);
|
||||
half2 val28 = *((half2*)(data1+alu16+131080));
|
||||
half2 val29 = *((half2*)(data1+alu16+135176));
|
||||
half8 cast7 = make_half8(val19.x,val19.y,val20.x,val20.y,val28.x,val28.y,val29.x,val29.y);
|
||||
float4 wmma8 = __WMMA_8_16_16_half_float(cast7, cast1, acc9);
|
||||
float4 wmma9 = __WMMA_8_16_16_half_float(cast7, cast2, acc10);
|
||||
float4 wmma10 = __WMMA_8_16_16_half_float(cast7, cast3, acc11);
|
||||
float4 wmma11 = __WMMA_8_16_16_half_float(cast7, cast4, acc8);
|
||||
half2 val30 = *((half2*)(data1+alu16+196616));
|
||||
half2 val31 = *((half2*)(data1+alu16+200712));
|
||||
half8 cast8 = make_half8(val21.x,val21.y,val22.x,val22.y,val30.x,val30.y,val31.x,val31.y);
|
||||
float4 wmma12 = __WMMA_8_16_16_half_float(cast8, cast1, acc13);
|
||||
float4 wmma13 = __WMMA_8_16_16_half_float(cast8, cast2, acc14);
|
||||
float4 wmma14 = __WMMA_8_16_16_half_float(cast8, cast3, acc15);
|
||||
float4 wmma15 = __WMMA_8_16_16_half_float(cast8, cast4, acc12);
|
||||
acc0 = wmma3;
|
||||
acc1 = wmma0;
|
||||
acc2 = wmma1;
|
||||
acc3 = wmma2;
|
||||
acc4 = wmma7;
|
||||
acc5 = wmma4;
|
||||
acc6 = wmma5;
|
||||
acc7 = wmma6;
|
||||
acc8 = wmma11;
|
||||
acc9 = wmma8;
|
||||
acc10 = wmma9;
|
||||
acc11 = wmma10;
|
||||
acc12 = wmma15;
|
||||
acc13 = wmma12;
|
||||
acc14 = wmma13;
|
||||
acc15 = wmma14;
|
||||
}
|
||||
*((half2*)(data0+alu12+8)) = make_half2((half)(acc1.x),(half)(acc1.y));
|
||||
*((half2*)(data0+alu12+16)) = make_half2((half)(acc2.x),(half)(acc2.y));
|
||||
*((half2*)(data0+alu12+24)) = make_half2((half)(acc3.x),(half)(acc3.y));
|
||||
*((half2*)(data0+alu12+4096)) = make_half2((half)(acc0.z),(half)(acc0.w));
|
||||
*((half2*)(data0+alu12+4104)) = make_half2((half)(acc1.z),(half)(acc1.w));
|
||||
*((half2*)(data0+alu12+4112)) = make_half2((half)(acc2.z),(half)(acc2.w));
|
||||
*((half2*)(data0+alu12+4120)) = make_half2((half)(acc3.z),(half)(acc3.w));
|
||||
*((half2*)(data0+alu12+65536)) = make_half2((half)(acc4.x),(half)(acc4.y));
|
||||
*((half2*)(data0+alu12+65544)) = make_half2((half)(acc5.x),(half)(acc5.y));
|
||||
*((half2*)(data0+alu12+65552)) = make_half2((half)(acc6.x),(half)(acc6.y));
|
||||
*((half2*)(data0+alu12+65560)) = make_half2((half)(acc7.x),(half)(acc7.y));
|
||||
*((half2*)(data0+alu12+69632)) = make_half2((half)(acc4.z),(half)(acc4.w));
|
||||
*((half2*)(data0+alu12+69640)) = make_half2((half)(acc5.z),(half)(acc5.w));
|
||||
*((half2*)(data0+alu12+69648)) = make_half2((half)(acc6.z),(half)(acc6.w));
|
||||
*((half2*)(data0+alu12+69656)) = make_half2((half)(acc7.z),(half)(acc7.w));
|
||||
*((half2*)(data0+alu12+131072)) = make_half2((half)(acc8.x),(half)(acc8.y));
|
||||
*((half2*)(data0+alu12+131080)) = make_half2((half)(acc9.x),(half)(acc9.y));
|
||||
*((half2*)(data0+alu12+131088)) = make_half2((half)(acc10.x),(half)(acc10.y));
|
||||
*((half2*)(data0+alu12+131096)) = make_half2((half)(acc11.x),(half)(acc11.y));
|
||||
*((half2*)(data0+alu12+135168)) = make_half2((half)(acc8.z),(half)(acc8.w));
|
||||
*((half2*)(data0+alu12+135176)) = make_half2((half)(acc9.z),(half)(acc9.w));
|
||||
*((half2*)(data0+alu12+135184)) = make_half2((half)(acc10.z),(half)(acc10.w));
|
||||
*((half2*)(data0+alu12+135192)) = make_half2((half)(acc11.z),(half)(acc11.w));
|
||||
*((half2*)(data0+alu12+196608)) = make_half2((half)(acc12.x),(half)(acc12.y));
|
||||
*((half2*)(data0+alu12+196616)) = make_half2((half)(acc13.x),(half)(acc13.y));
|
||||
*((half2*)(data0+alu12+196624)) = make_half2((half)(acc14.x),(half)(acc14.y));
|
||||
*((half2*)(data0+alu12+196632)) = make_half2((half)(acc15.x),(half)(acc15.y));
|
||||
*((half2*)(data0+alu12+200704)) = make_half2((half)(acc12.z),(half)(acc12.w));
|
||||
*((half2*)(data0+alu12+200712)) = make_half2((half)(acc13.z),(half)(acc13.w));
|
||||
*((half2*)(data0+alu12+200720)) = make_half2((half)(acc14.z),(half)(acc14.w));
|
||||
*((half2*)(data0+alu12+200728)) = make_half2((half)(acc15.z),(half)(acc15.w));
|
||||
*((half2*)(data0+alu12)) = make_half2((half)(acc0.x),(half)(acc0.y));
|
||||
}
@@ -0,0 +1,398 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define N_PAD 132
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
|
||||
|
||||
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
|
||||
|
||||
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
|
||||
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
|
||||
int grid_m = blockIdx.x; /* M//64 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int threads = threadIdx.x; /* 128 */
|
||||
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
|
||||
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
|
||||
int wg_threads = threads%32;
|
||||
int num_k_blocks = K / 64;
|
||||
|
||||
// load indexes
|
||||
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
|
||||
// swizzled smem store offsets - columns of smem are swizzled
|
||||
// here's a link to a description of the Triton swizzle: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
|
||||
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
|
||||
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
|
||||
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
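// note on the swizzle: each thread copies one 16-byte (8-half) chunk per cp.async, and the
// 8-half chunk index is XORed with the low bits of the destination row
// (A: chunk = (threads%8) ^ ((threads/8)%8), B: chunk = (threads%16) ^ (threads/16)),
// so the 8 rows of every 8x8 tile that ldmatrix later reads start in 8 different
// 16-byte chunks and the shared-memory reads stay bank-conflict free.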
|
||||
|
||||
// ldmatrix indices
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// swizzled ldmatrix
|
||||
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
|
||||
size_t load_smem_a_phase = (threads / 16) % 2; // r4
|
||||
|
||||
size_t load_smem_b_row = (threads % 16) * 128; // r299
|
||||
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
|
||||
|
||||
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
|
||||
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
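// read side of the same swizzle: load_smem_b_phase encodes which chunk this lane starts in,
// the +0/+4/+8/+12 terms step across the 128-wide B tile in 32-column strides, and the
// XOR with threads%8 mirrors the store-side swizzle so each ldmatrix still sees whole 8x8 tiles.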
|
||||
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
|
||||
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
|
||||
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
|
||||
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
|
||||
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
|
||||
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
|
||||
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
|
||||
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
|
||||
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
|
||||
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
|
||||
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
|
||||
|
||||
// create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B_2 16384 bytes)
|
||||
__shared__ alignas(16) char smem[49152];
|
||||
|
||||
// create accs (16 WMMAs and 4 output elements each) and zero
|
||||
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
|
||||
// create registers for block A elements (2)
|
||||
half8 a_frag_0;
|
||||
half8 a_frag_1;
|
||||
|
||||
// create registers for block B elements (8)
|
||||
half4 b_frag_0;
|
||||
half4 b_frag_1;
|
||||
half4 b_frag_2;
|
||||
half4 b_frag_3;
|
||||
half4 b_frag_4;
|
||||
half4 b_frag_5;
|
||||
half4 b_frag_6;
|
||||
half4 b_frag_7;

half *smem_a_even = (half *)(smem);
half *smem_a_odd = (half *)(smem + 8192);
half *smem_b_even = (half *)(smem + 16384);
half *smem_b_odd = (half *)(smem + 32768);
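// two-stage ping-pong: each A stage is 64x64 halves (8 KiB) and each B stage 64x128 halves
// (16 KiB), so the even/odd pairs fill the whole 48 KiB buffer; one stage is consumed by
// ldmatrix/mma while cp.async fills the other.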
|
||||
|
||||
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
|
||||
|
||||
// start first pre-fetch load A
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
// start first pre-fetch load B
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
__syncthreads();
|
||||
|
||||
// start second pre-fetch load A
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
// start second pre-fetch load B
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
|
||||
// wait on needed prefetch value
|
||||
__pipeline_wait_prior(0); // TODO: this enables fast iterations, but incorrect results with 1 (it shouldn't)
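// __pipeline_wait_prior(n) blocks until at most n committed cp.async batches are still in
// flight: 0 waits for both prefetches above, 1 would only wait for the first of them.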
|
||||
__syncthreads();
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
|
||||
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
|
||||
half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd;
|
||||
half *smem_b_curr = (block_k % 2) ? smem_b_even : smem_b_odd;
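// buffer parity: iteration 0 consumes the second prefetch (odd) and iteration 1 the first
// (even), so the first two K blocks are accumulated in swapped order; this is harmless
// because the K loop is a pure accumulation and the order of the partial sums does not matter.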
|
||||
|
||||
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// last 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// prefetch next iteration if needed
|
||||
__syncthreads();
|
||||
if (block_k < (num_k_blocks-2)) {
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
}
|
||||
__pipeline_commit();
|
||||
|
||||
if (block_k < num_k_blocks-1) {
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();

// slower way: write floats one by one to data0
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
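// m16n8k16 accumulator layout: lane%4 selects a pair of adjacent columns and lane/4 selects
// the row, with .x/.y landing in row (lane/4) and .z/.w in row (lane/4)+8, which is why
// every fragment below is written as two pairs of elements 8*N apart.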
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
|
||||
wg_c_off += 32*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
|
||||
}
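// example usage (a sketch only, not part of this file): launching the kernel above for a
// row-major M x K by K x N matmul, assuming M % 64 == 0, N % 128 == 0, K % 64 == 0 and that
// data0/data1/data2 are already-allocated device buffers.
//
//   dim3 grid(M / 64, N / 128);   // one thread block per 64x128 tile of the output
//   dim3 block(128);              // 4 warps, matching __launch_bounds__(128)
//   wmma_example<<<grid, block>>>(data0, data1, data2, N, K);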
363
extra/gemm/max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu
Normal file
@@ -0,0 +1,363 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define N_PAD 132
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
|
||||
|
||||
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
|
||||
|
||||
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
|
||||
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
|
||||
int grid_m = blockIdx.x; /* M//64 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int threads = threadIdx.x; /* 128 */
|
||||
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
|
||||
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
|
||||
int wg_threads = threads%32;
|
||||
int num_k_blocks = K / 64;
|
||||
|
||||
// load indexes
|
||||
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);

// non-swizzled - works, but slowly, due to smem bank conflicts
size_t store_smem_a_off = ((threads % 8) * 8) + ((threads / 8) * 64);
size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128);
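// with this linear layout the 8 rows of an 8x8 ldmatrix tile are 64 (A) or 128 (B) halves
// apart, i.e. a multiple of 128 bytes, so all 8 row addresses map to the same banks and
// each ldmatrix is serialized roughly 8-way, which is the bank conflict mentioned above.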
|
||||
|
||||
// ldmatrix indices
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// unswizzled ldmatrix
|
||||
size_t load_smem_a_0_k_0 = (wg_m * 16 * 64) + ((wg_threads % 8) * 64) + (((wg_threads / 8) % 2) * 64 * 8) + ((wg_threads / 16) * 8);
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32*64);
|
||||
size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 8) * 128) + (((wg_threads / 8) % 2) * 128 * 8) + ((wg_threads / 16) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32;
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64;
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96;
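// unswizzled x4 ldmatrix addressing: wg_threads%8 is the row inside an 8x8 tile,
// (wg_threads/8)%2 picks the lower/upper 8 rows and wg_threads/16 the left/right 8 columns,
// matching the [ A | C ; B | D ] tile order sketched above.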
|
||||
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16;
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_1_k_0 + 16;
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_2 = load_smem_a_0_k_0 + 32;
|
||||
size_t load_smem_a_1_k_2 = load_smem_a_1_k_0 + 32;
|
||||
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
|
||||
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
|
||||
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
|
||||
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_3 = load_smem_a_0_k_0 + 48;
|
||||
size_t load_smem_a_1_k_3 = load_smem_a_1_k_0 + 48;
|
||||
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
|
||||
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
|
||||
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
|
||||
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
|
||||
|
||||
// create shared mem (A 8192 bytes, B 16384 bytes)
|
||||
__shared__ alignas(16) char smem[24576];
|
||||
|
||||
// create accs (16 WMMAs and 4 output elements each) and zero
|
||||
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
|
||||
// create registers for block A elements (2)
|
||||
half8 a_frag_0;
|
||||
half8 a_frag_1;
|
||||
|
||||
// create registers for block B elements (8)
|
||||
half4 b_frag_0;
|
||||
half4 b_frag_1;
|
||||
half4 b_frag_2;
|
||||
half4 b_frag_3;
|
||||
half4 b_frag_4;
|
||||
half4 b_frag_5;
|
||||
half4 b_frag_6;
|
||||
half4 b_frag_7;

half *smem_a = (half *)(smem);
half *smem_b = (half *)(smem + 8192);
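// single-buffered variant: one 24 KiB stage (A 8 KiB + B 16 KiB), so each iteration must
// wait for its own prefetch before computing and the global->shared copy never overlaps the
// mma work, unlike the 2-stage kernels.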
|
||||
|
||||
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
|
||||
|
||||
// start first pre-fetch load A
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
// start first pre-fetch load B
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
__syncthreads();
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
// wait on needed prefetch value
__pipeline_wait_prior(0);
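// waits for the copy committed at the end of the previous iteration (or, on the first pass,
// for the pre-loop prefetch) before its tiles are consumed from smem.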
|
||||
__syncthreads();
|
||||
|
||||
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
|
||||
half *smem_a_curr = smem_a;
|
||||
half *smem_b_curr = smem_b;
|
||||
|
||||
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// last 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// prefetch next iteration if needed
|
||||
__syncthreads();
|
||||
if (block_k < (num_k_blocks-1)) {
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
}
|
||||
__pipeline_commit();
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// slower way: write floats one by one to data0
|
||||
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
|
||||
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
|
||||
wg_c_off += 32*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
|
||||
}
439
extra/gemm/max_kernels/nv.fp16_fp32_fp32.max.cu
Normal file
@@ -0,0 +1,439 @@
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define N_PAD 132
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
|
||||
|
||||
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
|
||||
|
||||
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
|
||||
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
|
||||
int grid_m = blockIdx.x; /* M//64 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int threads = threadIdx.x; /* 128 */
|
||||
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
|
||||
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
|
||||
int wg_threads = threads%32;
|
||||
int num_k_blocks = K / 64;
|
||||
|
||||
// load indexes
|
||||
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
|
||||
// swizzled smem store offsets - columns of smem are swizzled
|
||||
// here's a link to a description of the Triton swizzle scheme: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
|
||||
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
|
||||
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
|
||||
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
|
||||
|
||||
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// swizzled ldmatrix
|
||||
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
|
||||
size_t load_smem_a_phase = (threads / 16) % 2; // r4
|
||||
|
||||
size_t load_smem_b_row = (threads % 16) * 128; // r299
|
||||
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
|
||||
|
||||
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
|
||||
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
|
||||
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
|
||||
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
|
||||
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
|
||||
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
|
||||
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
|
||||
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
|
||||
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
|
||||
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
|
||||
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
|
||||
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
|
||||
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
|
||||
|
||||
// create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B_2 16384 bytes)
|
||||
__shared__ alignas(16) char smem[49152];
|
||||
|
||||
// create accs (16 WMMAs and 4 output elements each) and zero
|
||||
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
|
||||
// create registers for block A elements (2)
|
||||
half8 a_frag_0;
|
||||
half8 a_frag_1;
|
||||
|
||||
// create registers for block B elements (8)
|
||||
half4 b_frag_0;
|
||||
half4 b_frag_1;
|
||||
half4 b_frag_2;
|
||||
half4 b_frag_3;
|
||||
half4 b_frag_4;
|
||||
half4 b_frag_5;
|
||||
half4 b_frag_6;
|
||||
half4 b_frag_7;
|
||||
|
||||
half *smem_a_even = (half *)(smem);
|
||||
half *smem_a_odd = (half *)(smem + 8192);
|
||||
half *smem_b_even = (half *)(smem + 16384);
|
||||
half *smem_b_odd = (half *)(smem + 32768);
|
||||
|
||||
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
|
||||
|
||||
// start first pre-fetch load A
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
// start first pre-fetch load B
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
__syncthreads();
|
||||
|
||||
// start second pre-fetch load A
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
// start second pre-fetch load B
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
|
||||
// wait on needed prefetch value
|
||||
__pipeline_wait_prior(0); // TODO: this enables fast iterations, but incorrect results with 1 (it shouldn't)
|
||||
__syncthreads();
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
|
||||
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
|
||||
half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd;
|
||||
half *smem_b_curr = (block_k % 2) ? smem_b_even : smem_b_odd;
|
||||
|
||||
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// last 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// prefetch next iteration if needed
|
||||
__syncthreads();
|
||||
if (block_k < (num_k_blocks-2)) {
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
}
|
||||
__pipeline_commit();
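// note: once block_k >= num_k_blocks-2 nothing was enqueued above, so this is an empty commit that keeps the __pipeline_wait_prior stage count consistent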
|
||||
|
||||
if (block_k < num_k_blocks-1) {
|
||||
__pipeline_wait_prior(1);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// store registers to smem first, then read back to do float4 writes to global
|
||||
float *smem_d = (float *)(smem);
|
||||
size_t smem_d_off = (wg_m * 16 * N_PAD) + (wg_n * 16) + ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N_PAD);
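// e.g. wg_m=0, wg_n=1, lane 5: (5%4)*2 = 2 and (5/4)%8 = 1, so acc_frag_0_0.x/.y land at row 1, cols 18-19 (and .z/.w at row 9) of the 32x128 float staging tile; N_PAD=132 pads each row to avoid bank conflicts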
|
||||
smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_0_0.x;
|
||||
smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_0_0.y;
|
||||
smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.z;
|
||||
smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.w;
|
||||
smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_0_1.x;
|
||||
smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_0_1.y;
|
||||
smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.z;
|
||||
smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.w;
|
||||
smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_0_2.x;
|
||||
smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_0_2.y;
|
||||
smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.z;
|
||||
smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.w;
|
||||
smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_0_3.x;
|
||||
smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_0_3.y;
|
||||
smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.z;
|
||||
smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.w;
|
||||
smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_0_4.x;
|
||||
smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_0_4.y;
|
||||
smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.z;
|
||||
smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.w;
|
||||
smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_0_5.x;
|
||||
smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_0_5.y;
|
||||
smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.z;
|
||||
smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.w;
|
||||
smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_0_6.x;
|
||||
smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_0_6.y;
|
||||
smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_0_6.z;
|
||||
smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_0_6.w;
|
||||
smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_0_7.x;
|
||||
smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_0_7.y;
|
||||
smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_0_7.z;
|
||||
smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_0_7.w;
|
||||
|
||||
__syncthreads();
|
||||
size_t load_smem_d_off = ((threads % 32) * 4) + ((threads / 32) * N_PAD);
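// read back as float4: (threads%32)*4 covers one 128-float row with 32 threads, and threads/32 selects one of 4 consecutive rows per pass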
|
||||
float4 d_0_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
|
||||
float4 d_0_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
|
||||
float4 d_0_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
|
||||
float4 d_0_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
|
||||
float4 d_0_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
|
||||
float4 d_0_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
|
||||
float4 d_0_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
|
||||
float4 d_0_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));
|
||||
|
||||
__syncthreads();
|
||||
smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_1_0.x;
|
||||
smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_1_0.y;
|
||||
smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.z;
|
||||
smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.w;
|
||||
smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_1_1.x;
|
||||
smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_1_1.y;
|
||||
smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.z;
|
||||
smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.w;
|
||||
smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_1_2.x;
|
||||
smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_1_2.y;
|
||||
smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.z;
|
||||
smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.w;
|
||||
smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_1_3.x;
|
||||
smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_1_3.y;
|
||||
smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.z;
|
||||
smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.w;
|
||||
smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_1_4.x;
|
||||
smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_1_4.y;
|
||||
smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.z;
|
||||
smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.w;
|
||||
smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_1_5.x;
|
||||
smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_1_5.y;
|
||||
smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.z;
|
||||
smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.w;
|
||||
smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_1_6.x;
|
||||
smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_1_6.y;
|
||||
smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_1_6.z;
|
||||
smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_1_6.w;
|
||||
smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_1_7.x;
|
||||
smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_1_7.y;
|
||||
smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_1_7.z;
|
||||
smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_1_7.w;
|
||||
|
||||
__syncthreads();
|
||||
float4 d_1_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD)));
|
||||
float4 d_1_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD)));
|
||||
float4 d_1_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD)));
|
||||
float4 d_1_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD)));
|
||||
float4 d_1_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD)));
|
||||
float4 d_1_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD)));
|
||||
float4 d_1_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD)));
|
||||
float4 d_1_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD)));
|
||||
|
||||
__syncthreads();
|
||||
float *global_d = &data0[((grid_m * 64) * N) + (grid_n * 128) + ((threads % 32) * 4) + ((threads / 32) * N)];
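// each of the 128 threads streams out float4 chunks: one 16-byte chunk per row, stepping 4 rows between stores to cover the 64x128 output tile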
|
||||
*((float4 *)(global_d + 0*N)) = d_0_0;
|
||||
*((float4 *)(global_d + 4*N)) = d_0_1;
|
||||
*((float4 *)(global_d + 8*N)) = d_0_2;
|
||||
*((float4 *)(global_d + 12*N)) = d_0_3;
|
||||
*((float4 *)(global_d + 16*N)) = d_0_4;
|
||||
*((float4 *)(global_d + 20*N)) = d_0_5;
|
||||
*((float4 *)(global_d + 24*N)) = d_0_6;
|
||||
*((float4 *)(global_d + 28*N)) = d_0_7;
|
||||
*((float4 *)(global_d + 32*N)) = d_1_0;
|
||||
*((float4 *)(global_d + 36*N)) = d_1_1;
|
||||
*((float4 *)(global_d + 40*N)) = d_1_2;
|
||||
*((float4 *)(global_d + 44*N)) = d_1_3;
|
||||
*((float4 *)(global_d + 48*N)) = d_1_4;
|
||||
*((float4 *)(global_d + 52*N)) = d_1_5;
|
||||
*((float4 *)(global_d + 56*N)) = d_1_6;
|
||||
*((float4 *)(global_d + 60*N)) = d_1_7;
|
||||
}
|
||||
371
extra/gemm/max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu
Normal file
@@ -0,0 +1,371 @@
|
||||
#define INFINITY (__int_as_float(0x7f800000))
|
||||
#define NAN (__int_as_float(0x7fffffff))
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline.h>
|
||||
#define N_PAD 132
|
||||
|
||||
struct __align__(8) half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
|
||||
|
||||
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
|
||||
|
||||
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
|
||||
addr[0] = reg0;
|
||||
addr[1] = reg1;
|
||||
addr[2] = reg2;
|
||||
addr[3] = reg3;
|
||||
}
|
||||
|
||||
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
|
||||
uint32_t reg0, reg1, reg2, reg3;
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
|
||||
: "l"(__cvta_generic_to_shared(smem))
|
||||
);
|
||||
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
|
||||
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
|
||||
addr_lo[0] = reg0;
|
||||
addr_lo[1] = reg1;
|
||||
addr_hi[0] = reg2;
|
||||
addr_hi[1] = reg3;
|
||||
}
|
||||
|
||||
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
|
||||
int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
|
||||
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
|
||||
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
|
||||
return c;
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) {
|
||||
int grid_m = blockIdx.x; /* M//64 */
|
||||
int grid_n = blockIdx.y; /* N//128 */
|
||||
int threads = threadIdx.x; /* 128 */
|
||||
int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks
|
||||
int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton
|
||||
int wg_threads = threads%32;
|
||||
int num_k_blocks = K / 64;
|
||||
|
||||
// load indexes
|
||||
size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K);
|
||||
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
|
||||
|
||||
// swizzled smem store offsets - columns of smem are swizzled
|
||||
// here's a link to a description of the Triton swizzling scheme: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579
|
||||
// see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh
|
||||
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15
|
||||
size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19
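// e.g. store_smem_a_off == (threads/8)*64 + (((threads%8) ^ ((threads/8)%8)) * 8) and store_smem_b_off == (threads/16)*128 + (((threads/16) ^ (threads%16)) * 8):
// each row's 16-byte chunks are XOR-permuted by the row index, so later ldmatrix reads down a column hit different banks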
|
||||
|
||||
// ldmatrix indices
|
||||
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
|
||||
// [ A | C ]
|
||||
// [ - + - ]
|
||||
// [ B | D ]
|
||||
|
||||
// swizzled ldmatrix
|
||||
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293
|
||||
size_t load_smem_a_phase = (threads / 16) % 2; // r4
|
||||
|
||||
size_t load_smem_b_row = (threads % 16) * 128; // r299
|
||||
size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order)
|
||||
|
||||
size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38
|
||||
size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64);
|
||||
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
|
||||
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
|
||||
|
||||
size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316;
|
||||
size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64);
|
||||
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
|
||||
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
|
||||
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
|
||||
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319;
|
||||
size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64);
|
||||
size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128);
|
||||
size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128);
|
||||
size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128);
|
||||
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128);
|
||||
|
||||
size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322;
|
||||
size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64);
|
||||
size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128);
|
||||
size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128);
|
||||
size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128);
|
||||
size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128);
|
||||
|
||||
// create shared mem (A 8192 bytes, B 16384 bytes)
|
||||
__shared__ alignas(16) char smem[24576];
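// A tile: 64x64 halfs = 8192 bytes, B tile: 64x128 halfs = 16384 bytes; single-buffered, unlike the double-buffered variants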
|
||||
|
||||
// create accs (16 WMMAs and 4 output elements each) and zero
|
||||
float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f);
|
||||
|
||||
// create registers for block A elements (2)
|
||||
half8 a_frag_0;
|
||||
half8 a_frag_1;
|
||||
|
||||
// create register for block B elements (8)
|
||||
half4 b_frag_0;
|
||||
half4 b_frag_1;
|
||||
half4 b_frag_2;
|
||||
half4 b_frag_3;
|
||||
half4 b_frag_4;
|
||||
half4 b_frag_5;
|
||||
half4 b_frag_6;
|
||||
half4 b_frag_7;
|
||||
|
||||
half *smem_a = (half *)(smem);
|
||||
half *smem_b = (half *)(smem + 8192);
|
||||
|
||||
// https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies
|
||||
|
||||
// start first pre-fetch load A
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
// start first pre-fetch load B
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
__pipeline_commit();
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
__syncthreads();
|
||||
|
||||
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
|
||||
// wait on needed prefetch value
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
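// single smem buffer: block until every outstanding async copy (the prefetch issued at the end of the previous iteration, or the initial one) has landed before reading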
|
||||
|
||||
// BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma
|
||||
half *smem_a_curr = smem_a;
|
||||
half *smem_b_curr = smem_b;
|
||||
|
||||
// first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// next 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// last 16 K elements
|
||||
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
|
||||
__ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
|
||||
__ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
|
||||
acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0);
|
||||
acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1);
|
||||
acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2);
|
||||
acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3);
|
||||
acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4);
|
||||
acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5);
|
||||
acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6);
|
||||
acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7);
|
||||
acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0);
|
||||
acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1);
|
||||
acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2);
|
||||
acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3);
|
||||
acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4);
|
||||
acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5);
|
||||
acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6);
|
||||
acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7);
|
||||
|
||||
// prefetch next iteration if needed
|
||||
__syncthreads();
|
||||
if (block_k < (num_k_blocks-1)) {
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
|
||||
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
|
||||
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
|
||||
__pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
|
||||
|
||||
global_a_off += 64;
|
||||
global_b_off += 64 * N;
|
||||
}
|
||||
__pipeline_commit();
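// commit the prefetch for the next iteration (an empty commit on the last iteration)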
|
||||
}
|
||||
|
||||
// write accumulators to output
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
// slower way: write floats one by one to data0
|
||||
size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
|
||||
size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
|
||||
wg_c_off += 32*N;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
|
||||
data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
|
||||
data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
|
||||
data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
|
||||
}
|
||||
232
extra/gemm/max_matmul.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import numpy as np, os
|
||||
from tinygrad.helpers import getenv, flat_mv
|
||||
from tinygrad import dtypes
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Self
|
||||
|
||||
# for copied uops
|
||||
from tinygrad.codegen.kernel import Kernel, KernelOptError
|
||||
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps, KernelInfo
|
||||
from tinygrad.engine.search import Opt, OptOps
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.dtype import PtrDType, DType, DTYPES_DICT
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# problem variations
|
||||
DTYPE_IN = DTYPES_DICT[getenv("DTYPE_IN", "half")]
|
||||
DTYPE_OUT = DTYPES_DICT[getenv("DTYPE_OUT", "half")]
|
||||
DTYPE_ACC = DTYPES_DICT[getenv("DTYPE_ACC", "float")]
|
||||
N = getenv("N", 4096)
|
||||
M = getenv("M", N)
|
||||
K = getenv("K", N)
|
||||
CNT = getenv("CNT", 10)
|
||||
ATOL = getenv("ATOL", 5e-3 if DTYPE_IN == dtypes.float else 1e-2)
|
||||
RTOL = getenv("RTOL", 1e-4 if DTYPE_IN == dtypes.float else 1e-3)
|
||||
FLOPS = M * N * K * 2
|
||||
BW = 2 * ((M*K) + (K*N) + (M*N))
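# rough bandwidth estimate: each of A, B and C touched once, assuming 2-byte (half) elements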
|
||||
|
||||
# algorithm variations
|
||||
INPUT = getenv("INPUT", "RAND")
|
||||
GEMM_VARIATION = getenv("GEMM_VARIATION", "hcopt")  # default must match one of the variation branches below
|
||||
|
||||
def randoms():
|
||||
if INPUT == "RAND":
|
||||
na = np.random.default_rng().normal(scale=1.0, size=(M,K)).astype(dtype=np.float32)
|
||||
nb = np.random.default_rng().normal(scale=1.0, size=(K,N)).astype(dtype=np.float32)
|
||||
elif INPUT == "IDENTITY" and M==N==K:
|
||||
na = np.identity(K, dtype=np.float32)
|
||||
nb = np.identity(K, dtype=np.float32)
|
||||
elif INPUT == "OUTPUTONES" and M==K:
|
||||
na = np.identity(K, dtype=np.float32)
|
||||
nb = np.ones((K,N), dtype=np.float32)
|
||||
else:
|
||||
na = np.ones((M,K), dtype=np.float32)
|
||||
nb = np.ones((K,N), dtype=np.float32)
|
||||
nc = np.zeros(M*N, np.float32)
|
||||
if DTYPE_IN != dtypes.float:
|
||||
na = na.astype(np.bfloat16 if DTYPE_IN == dtypes.bfloat16 else np.float16)
|
||||
nb = nb.astype(np.bfloat16 if DTYPE_IN == dtypes.bfloat16 else np.float16)
|
||||
if DTYPE_OUT != dtypes.float:
|
||||
nc = nc.astype(np.bfloat16 if DTYPE_OUT == dtypes.bfloat16 else np.float16)
|
||||
return na, nb, nc
|
||||
|
||||
def ast_to_cuda_prog(compiler, ast, opts):
|
||||
k = Kernel(ast)
|
||||
k.required_optimizations()
|
||||
for opt in opts:
|
||||
k.apply_opt(opt)
|
||||
p = k.to_program()
|
||||
return CUDAProgram(device, p.function_name, compiler.compile(p.src))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
|
||||
prog, global_size, local_size = None, None, None
|
||||
|
||||
if getenv("CUDA") == 1:
|
||||
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
|
||||
device = CUDADevice("cuda:0")
|
||||
compiler = CUDACompiler(device.arch)
|
||||
cudaalloc = CUDAAllocator(device)
|
||||
|
||||
a = cudaalloc.alloc(M*K*DTYPE_IN.itemsize)
|
||||
b = cudaalloc.alloc(K*N*DTYPE_IN.itemsize)
|
||||
c = cudaalloc.alloc(M*N*DTYPE_OUT.itemsize)
|
||||
|
||||
if GEMM_VARIATION == "max" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
|
||||
print("Using CUDA and triton-generated kernel")
|
||||
# See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py`
|
||||
# this kernel with M=N=K=4096 does 162TFLOPS, vs torch at 144TFLOPS and BEAM=8 tinygrad at 138TFLOPS. theo max is 165TFLOPS.
|
||||
|
||||
# WMMA element size is (M, N, K) = (16, 8, 16)
|
||||
# warpgroup size in WMMA tiles is (B_M, B_N, B_K) = (2, 8, 4) so 64 HMMA calls per threadgroup reduce iteration
|
||||
# thread block size is (T_M, T_N, T_K) = (2, 2, 1), i.e. macro blocks in M and N, so 256 HMMA calls per kernel reduce iteration
|
||||
# kernel reduce iteration size in elements = (64, 128, 64)
|
||||
# single iteration SMEM_A = (64 * 64) * (2 bytes / half) = 8192 bytes, SMEM_B = (128 * 64) * (2 bytes / half) = 16384 bytes
|
||||
# double-buffer smem = (8192 + 16384) * 2 = 49152 bytes
|
||||
# reduce for_loop size = [1, 1, (4096 // 16 // 4)==64]
|
||||
# NOTE: T_K > 0 would be group_for_reduce
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.max.cu')).read()))
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//64, N//128, 1],
|
||||
'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "2_stage_swizzled_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
|
||||
print("Using CUDA, 2-stage reduce pipeline, swizzled SMEM inputs")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu')).read()))
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//64, N//128, 1],
|
||||
'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "swizzled_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
|
||||
print("Using CUDA, swizzled SMEM inputs")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu')).read()))
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//64, N//128, 1],
|
||||
'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "flat_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
|
||||
print("Using CUDA, flat SMEM inputs")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu')).read()))
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//64, N//128, 1],
|
||||
'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "hcopt" and M == N == K == 4096 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.float:
|
||||
print("Using CUDA and generated hcopt")
|
||||
# [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4)]
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp16.hcopt.cu')).read()))
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [32, 64, 1],
|
||||
'local_size': [16, 2, 4], # 16,2 are warp, 4 workgroups upcasted to axis=1
|
||||
'wait': True,
|
||||
}
|
||||
elif GEMM_VARIATION == "2_stage" and (M%64)== 0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
|
||||
print("Using CUDA and un-optimized 2-stage, swizzled SMEM inputs and direct acc to output kernel")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.2_stage.cu')).read()))
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//64, N//128, 1],
|
||||
'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2)
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "3_stage" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
|
||||
print("Using CUDA and 3-stage (interleave global copies and ldmatrix)")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.3_stage.cu')).read()), 73728)
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//256, N//128, 1],
|
||||
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "3_stage_swizzled" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
|
||||
print("Using CUDA and 3-stage (interleave global copies and ldmatrix) and swizzled SMEM inputs")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu')).read()), 73728)
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//256, N//128, 1],
|
||||
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "max" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
|
||||
print("Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.max.cu')).read()), 73728)
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//256, N//128, 1],
|
||||
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
elif GEMM_VARIATION == "no_xor" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half:
|
||||
print("Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue")
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.no_xor.cu')).read()), 73728)
|
||||
args = (c, a, b)
|
||||
kwargs = {
|
||||
'global_size': [M//256, N//128, 1],
|
||||
'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2
|
||||
'wait': True,
|
||||
'vals': (N, K),
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(f"invalid gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
|
||||
|
||||
tms = []
|
||||
na, nb, nc = randoms()
|
||||
cudaalloc.copyin(a, bytearray(na))
|
||||
cudaalloc.copyin(b, bytearray(nb))
|
||||
for i in range(CNT):
|
||||
tms.append(prog(*args, **kwargs))
|
||||
cudaalloc.copyout(flat_mv(nc.data), c)
|
||||
comp = na.astype(np.float32) @ nb.astype(np.float32)
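# ground truth: fp32 numpy matmul; the kernel output below is reshaped and upcast to fp32 before comparison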
|
||||
result = nc.reshape(M, N).astype(np.float32)
|
||||
|
||||
print(f"{N*N:10d} {min(tms)*1e6:9.2f} us, would be {FLOPS*1e-9/min(tms):9.2f} GFLOPS matmul, {BW*1e-9/min(tms):.2f} GB/s")
|
||||
try:
|
||||
np.testing.assert_allclose(result, comp, atol=ATOL, rtol=RTOL)
|
||||
except AssertionError as e:
|
||||
if getenv("DEBUG_VALUES") > 0:
|
||||
indices = np.where(~np.isclose(result, comp, rtol=RTOL, atol=ATOL))
|
||||
non_matching_elements_result = result[indices]
|
||||
non_matching_elements_comp = comp[indices]
|
||||
print("valid :", np.where(np.isclose(result, comp, rtol=RTOL, atol=ATOL)))
|
||||
print("invalid :", indices)
|
||||
print("result :", non_matching_elements_result)
|
||||
print("ground truth:", non_matching_elements_comp)
|
||||
print("result sum :", np.sum(result))
|
||||
print("ground sum :", np.sum(comp))
|
||||
raise e
|
||||
|
||||
if getenv("DEBUG_VALUES") > 0:
|
||||
print(comp)
|
||||
print("ground sum :", np.sum(comp))
|
||||
print(result)
|
||||
print("result sum :", np.sum(result))
|
||||
|
||||
elif getenv("AMD") == 1:
|
||||
# note: https://hipfft.readthedocs.io/en/rocm-6.1.2/how-to/fine-tuning-llms/optimizing-triton-kernel.html
|
||||
|
||||
# also this is different than the rocblas/tensile approach to GEMM
|
||||
# see: https://github.com/ROCm/Tensile/blob/develop/Tensile/KernelWriterAssembly.py
|
||||
raise RuntimeError("invalid max_matmul device")
|
||||
|
||||
else:
|
||||
raise RuntimeError("invalid max_matmul device")