From 1e5d9ad8f75beaa321659623cdcb67720b322086 Mon Sep 17 00:00:00 2001 From: Francis Lam Date: Wed, 19 Mar 2025 00:04:57 -0700 Subject: [PATCH] extra/gemm/max_matmul: start of custom kernels for GEMM (#6926) * extra/gemm/max_matmul: start of custom kernels for GEMM * add an unoptimized FP16/FP16 MMA example * add slow 3-stage fp16 acc example * add correct 3-stage pipeline with unswizzled/flat smem input (slow) * add acc fp16 example with 3 stages and swizzle (no bank conflicts) * add max version of NV fp16_fp16_fp16 * fix up comments and removed unused code in max variations * add start of no_xor example * fix to account for UOps to Ops --- .../max_kernels/nv.fp16_fp16_fp16.2_stage.cu | 508 +++++++++++++++++ .../max_kernels/nv.fp16_fp16_fp16.3_stage.cu | 465 ++++++++++++++++ .../nv.fp16_fp16_fp16.3_stage_swizzled.cu | 517 ++++++++++++++++++ .../gemm/max_kernels/nv.fp16_fp16_fp16.max.cu | 482 ++++++++++++++++ .../max_kernels/nv.fp16_fp16_fp16.no_xor.cu | 486 ++++++++++++++++ .../max_kernels/nv.fp16_fp32_fp16.hcopt.cu | 157 ++++++ ...6_fp32_fp32.2_stage_swizzled_smem_input.cu | 398 ++++++++++++++ .../nv.fp16_fp32_fp32.flat_smem_input.cu | 363 ++++++++++++ .../gemm/max_kernels/nv.fp16_fp32_fp32.max.cu | 439 +++++++++++++++ .../nv.fp16_fp32_fp32.swizzled_smem_input.cu | 371 +++++++++++++ extra/gemm/max_matmul.py | 232 ++++++++ 11 files changed, 4418 insertions(+) create mode 100644 extra/gemm/max_kernels/nv.fp16_fp16_fp16.2_stage.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp16_fp16.no_xor.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp32_fp16.hcopt.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp32_fp32.max.cu create mode 100644 extra/gemm/max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu create mode 100644 extra/gemm/max_matmul.py diff --git a/extra/gemm/max_kernels/nv.fp16_fp16_fp16.2_stage.cu b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.2_stage.cu new file mode 100644 index 0000000000..32879b1e64 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.2_stage.cu @@ -0,0 +1,508 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), 
"=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c); + asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };" + : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) { + int grid_m = blockIdx.x; /* M//64 */ + int grid_n = blockIdx.y; /* N//128 */ + int threads = threadIdx.x; /* 128 */ + int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks + int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton + int wg_threads = threads%32; + int num_k_blocks = K / 64; + + // load indexes + size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K); + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + + // swizzled smem store offsets - columns of smem are swizzled + // here's a link to a description of the triton: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579 + // see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15 + size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19\ + + // ldmatrix indices + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // swizzled ldmatrix + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293 + size_t load_smem_a_phase = (threads / 16) % 2; // r4 + + size_t load_smem_b_row = (threads % 16) * 128; // r299 + size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order) + + size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38 + size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64); + size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8); + size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8); + + size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316; + size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128); + + size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319; + size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 
* 64); + size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128); + size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128); + size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128); + size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128); + + size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322; + size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64); + size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128); + size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128); + size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128); + size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128); + + // create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B2_16384 bytes) + __shared__ alignas(16) char smem[49152]; + + // create accs (16 WMMAs and 4 output elements each) and zero + half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements (2) + half8 a_frag_0; + half8 a_frag_1; + + // create register for block B elements (8) + half4 b_frag_0; + half4 b_frag_1; + half4 b_frag_2; + half4 b_frag_3; + half4 b_frag_4; + half4 b_frag_5; + half4 b_frag_6; + half4 b_frag_7; + + half *smem_a_even = (half *)(smem); + half *smem_a_odd = (half *)(smem + 8192); + half *smem_b_even = (half *)(smem + 16384); + half *smem_b_odd = (half *)(smem + 32768); + + // https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/ + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies + + // start first pre-fetch load A + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start first pre-fetch load B + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + 
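+  // note: a quick sketch of the copy-size arithmetic behind this prefetch, given the 128 threads
+  // declared in __launch_bounds__(128): each __pipeline_memcpy_async moves 16 bytes (8 halves) per
+  // thread, i.e. 128*16 = 2048 bytes per call. The A tile (64 rows x 64 halves of K = 8192 bytes) is
+  // covered by the 4 A copies above, and the B tile (64 rows of K x 128 halves of N = 16384 bytes)
+  // by these 8 B copies, each call filling 8 of the 64 rows (hence the 8*128 / 8*N strides).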
__pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + __syncthreads(); + + // start second pre-fetch load A + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start second pre-fetch load B + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + + // wait on needed prefetch value + __pipeline_wait_prior(0); // TODO: this enables fast iterations, but incorrect results with 1 (it shouldn't) + __syncthreads(); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + // BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma + half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd; + half *smem_b_curr = (block_k % 2) ? 
smem_b_even : smem_b_odd; + + // first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8 + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]); + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]); + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + __ldmatrix_a_elems(&a_frag_0, 
&smem_a_curr[load_smem_a_0_k_2]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]); + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7); + + // last 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]); + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7); + + // prefetch next iteration if needed + __syncthreads(); + if (block_k < (num_k_blocks-2)) { + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + 
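+      // note: this refill targets smem_*_curr, the buffers whose tiles were just consumed by the
+      // ldmatrix loads above (the __syncthreads() before this block makes that reuse safe); with the
+      // two-buffer ping-pong they are read again at iteration block_k+2, so the guard skips the copy
+      // once the last two K tiles have already been issued.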
__pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + + global_a_off += 64; + global_b_off += 64 * N; + } + __pipeline_commit(); + + if (block_k < num_k_blocks-1) { + __pipeline_wait_prior(1); + __syncthreads(); + } + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // // store registers to smem first, then read back to do float4 writes to global + // float *smem_d = (float *)(smem); + // size_t smem_d_off = (wg_m * 16 * N_PAD) + (wg_n * 16) + ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N_PAD); + // smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_0_0.x; + // smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_0_0.y; + // smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.z; + // smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.w; + // smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_0_1.x; + // smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_0_1.y; + // smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.z; + // smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.w; + // smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_0_2.x; + // smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_0_2.y; + // smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.z; + // smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.w; + // smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_0_3.x; + // smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_0_3.y; + // smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.z; + // smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.w; + // smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_0_4.x; + // smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_0_4.y; + // smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.z; + // smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.w; + // smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_0_5.x; + // smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_0_5.y; + // smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.z; + // smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.w; + // smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_0_6.x; + // smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_0_6.y; + // smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_0_6.z; + // smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_0_6.w; + // smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_0_7.x; + // smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_0_7.y; + // smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_0_7.z; + // 
smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_0_7.w; + + // __syncthreads(); + // size_t load_smem_d_off = ((threads % 32) * 4) + ((threads / 32) * N_PAD); + // float4 d_0_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD))); + // float4 d_0_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD))); + // float4 d_0_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD))); + // float4 d_0_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD))); + // float4 d_0_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD))); + // float4 d_0_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD))); + // float4 d_0_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD))); + // float4 d_0_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD))); + + // __syncthreads(); + // smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_1_0.x; + // smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_1_0.y; + // smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.z; + // smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.w; + // smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_1_1.x; + // smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_1_1.y; + // smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.z; + // smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.w; + // smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_1_2.x; + // smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_1_2.y; + // smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.z; + // smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.w; + // smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_1_3.x; + // smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_1_3.y; + // smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.z; + // smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.w; + // smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_1_4.x; + // smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_1_4.y; + // smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.z; + // smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.w; + // smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_1_5.x; + // smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_1_5.y; + // smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.z; + // smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.w; + // smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_1_6.x; + // smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_1_6.y; + // smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_1_6.z; + // smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_1_6.w; + // smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_1_7.x; + // smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_1_7.y; + // smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_1_7.z; + // smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_1_7.w; + + // __syncthreads(); + // float4 d_1_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD))); + // float4 d_1_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD))); + // float4 d_1_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD))); + // float4 d_1_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD))); + // float4 d_1_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD))); + // float4 d_1_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD))); + // float4 d_1_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD))); + // float4 d_1_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD))); + + // __syncthreads(); + // float *global_d = &data0[((grid_m * 64) * N) + (grid_n * 128) + ((threads % 32) * 4) + ((threads / 32) * N)]; + // *((float4 
*)(global_d + 0*N)) = d_0_0; + // *((float4 *)(global_d + 4*N)) = d_0_1; + // *((float4 *)(global_d + 8*N)) = d_0_2; + // *((float4 *)(global_d + 12*N)) = d_0_3; + // *((float4 *)(global_d + 16*N)) = d_0_4; + // *((float4 *)(global_d + 20*N)) = d_0_5; + // *((float4 *)(global_d + 24*N)) = d_0_6; + // *((float4 *)(global_d + 28*N)) = d_0_7; + // *((float4 *)(global_d + 32*N)) = d_1_0; + // *((float4 *)(global_d + 36*N)) = d_1_1; + // *((float4 *)(global_d + 40*N)) = d_1_2; + // *((float4 *)(global_d + 44*N)) = d_1_3; + // *((float4 *)(global_d + 48*N)) = d_1_4; + // *((float4 *)(global_d + 52*N)) = d_1_5; + // *((float4 *)(global_d + 56*N)) = d_1_6; + // *((float4 *)(global_d + 60*N)) = d_1_7; + + // slower way: write floats one by one to data0 + size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16); + size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N); + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w; + wg_c_off += 32*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y; 
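+  // note: a sketch of the indexing used by these stores, following the mma.m16n8k16 f16 accumulator
+  // layout in the PTX ISA: each lane holds a 2x2 patch, .x/.y on row (lane/4) at columns (lane%4)*2
+  // and +1, .z/.w on row (lane/4)+8 -- which is what thread_c_off and the (8*N) terms express. The
+  // (0*8),(1*8),(4*8),(5*8),... column steps come from each ldmatrix.x4.trans producing a lo/hi pair
+  // of fragments 8 columns apart, with successive B loads starting 32 columns apart.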
+ data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w; +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage.cu b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage.cu new file mode 100644 index 0000000000..206fd9e32d --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage.cu @@ -0,0 +1,465 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) 
(&b), *c_pk = (int *) (&c); + asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };" + : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) { + extern __shared__ char smem[]; + half *smem_a_0 = (half *)(smem); + half *smem_a_1 = (half *)(smem + 16384); + half *smem_a_2 = (half *)(smem + 32768); + half *smem_b_0 = (half *)(smem + 49152); + half *smem_b_1 = (half *)(smem + 57344); + half *smem_b_2 = (half *)(smem + 65536); + + int grid_m = blockIdx.x; /* M//256 */ + int grid_n = blockIdx.y; /* N//128 */ + int wg_threads = threadIdx.x; // 32 + int wg_m = threadIdx.y; // 4 + int wg_n = threadIdx.z; // 2 + int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */ + int num_k_blocks = K / 32; + + // load indexes + size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + ((threads / 4) * K); + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + + // unswizzed smem store + size_t store_smem_a_off = ((threads % 4) * 8) + ((threads / 4) * 32); // 64 rows / 32 cols per copy + size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); // 16 rows / 128 cols per copy + + // ldmatrix indices + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // unswizzed ldmatrix + size_t load_smem_a_0_k_0 = (wg_m * 16 * 32) + ((wg_threads % 16) * 32) + ((wg_threads / 16) * 8); + size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + ( 64 * 32); + size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + (128 * 32); + size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (192 * 32); + + size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16; + size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + ( 64 * 32); + size_t load_smem_a_2_k_1 = load_smem_a_0_k_1 + (128 * 32); + size_t load_smem_a_3_k_1 = load_smem_a_0_k_1 + (192 * 32); + + size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 16) * 128) + ((wg_threads / 16) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32; + size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64; + size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96; + + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_0_k_1 + 32; + size_t load_smem_b_2_k_1 = load_smem_b_0_k_1 + 64; + size_t load_smem_b_3_k_1 = load_smem_b_0_k_1 + 96; + + // create accs (M=4, N=8) + half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_0 = 
make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements + half8 a_frag_0_k_0; + half8 a_frag_1_k_0; + half8 a_frag_2_k_0; + half8 a_frag_3_k_0; + half8 a_frag_0_k_1; + half8 a_frag_1_k_1; + half8 a_frag_2_k_1; + half8 a_frag_3_k_1; + + // create register for block B elements + half4 b_frag_0_k_0; + half4 b_frag_1_k_0; + half4 b_frag_2_k_0; + half4 b_frag_3_k_0; + half4 b_frag_4_k_0; + half4 b_frag_5_k_0; + half4 b_frag_6_k_0; + half4 b_frag_7_k_0; + half4 b_frag_0_k_1; + half4 b_frag_1_k_1; + half4 b_frag_2_k_1; + half4 b_frag_3_k_1; + half4 b_frag_4_k_1; + half4 b_frag_5_k_1; + half4 b_frag_6_k_1; + half4 b_frag_7_k_1; + + __syncthreads(); + + // load first tile + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + + global_a_off += 32; + global_b_off += 32 * N; + + // load second tile + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + + global_a_off += 32; + global_b_off += 32 * N; + + // wait on first pre-fetch load + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 for the first tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]); + 
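+  // note: a short sketch of what these helpers hand back. ldmatrix.x4 reads four 8x8 b16 tiles using
+  // one row address per lane (the A/B/C/D quadrant comment above), so __ldmatrix_a_elems returns a
+  // full 16x16 A block as one half8 per lane, while the .trans form in __ldmatrix_b_elems returns two
+  // n=8 B fragments (the lo/hi half4 pair), each feeding one m16n8k16 mma, 8 output columns apart.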
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + int phase_k = block_k % 3; + half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2); + + int next_phase_k = (block_k+1) % 3; + half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2); + + int store_phase_k = (block_k+2) % 3; + half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2); + + // load K=1 elements for the current tile + __ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]); + __ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]); + __ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]); + + // MMA K=0, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, 
b_frag_7_k_0, acc_frag_2_7); + acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7); + + // load next tile + if (block_k < (num_k_blocks-2)) { + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + global_a_off += 32; + global_b_off += 32 * N; + } + __pipeline_commit(); + + // wait next tile + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 for the next tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]); + + // MMA K=1, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, 
b_frag_7_k_1, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7); + acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7); + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // slower way: write accs one by one to data0 + + size_t wg_c_off = ((grid_m * 256) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16); + size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N); + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = 
acc_frag_0_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w; + + wg_c_off += 64*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w; + + wg_c_off += 64*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_2_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_2_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_2_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_2_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_2_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_2_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_2_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_2_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_2_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_2_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_2_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_2_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_2_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 
5*8)] = acc_frag_2_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_2_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_2_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_2_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_2_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_2_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_2_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_2_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_2_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_2_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_2_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_2_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_2_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_2_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_2_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_2_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_2_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_2_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_2_7.w; + + wg_c_off += 64*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_3_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_3_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_3_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_3_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_3_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_3_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_3_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_3_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_3_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_3_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_3_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_3_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_3_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_3_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_3_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_3_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_3_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_3_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_3_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_3_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_3_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_3_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_3_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_3_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_3_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_3_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_3_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_3_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_3_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_3_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_3_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_3_7.w; + +} diff --git 
a/extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu new file mode 100644 index 0000000000..0352666b23 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu @@ -0,0 +1,517 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c); + asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };" + : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) { + extern __shared__ char smem[]; + half *smem_a_0 = (half *)(smem); + half *smem_a_1 = (half *)(smem + 16384); + half *smem_a_2 = (half *)(smem + 32768); + half *smem_b_0 = (half *)(smem + 49152); + half *smem_b_1 = (half *)(smem + 57344); + half *smem_b_2 = (half *)(smem + 65536); + + int grid_m = blockIdx.x; /* M//256 */ + int grid_n = blockIdx.y; /* N//128 */ + int wg_threads = threadIdx.x; // 32 + int wg_m = threadIdx.y; // 4 + int wg_n = threadIdx.z; // 2 + int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */ + int num_k_blocks = K / 32; + + // ldmatrix indices + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // unswizzled A - SMEM_A is 256 rows x 32 cols + // size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + ((threads / 4) * K); + // size_t store_smem_a_off = ((threads % 4) * 8) + ((threads / 4) * 32); // 64 rows / 32 cols per copy + // size_t load_smem_a_0_k_0 = (wg_m * 16 * 32) + ((wg_threads % 16) * 32) + ((wg_threads / 16) * 8); + // size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + ( 64 * 32); + // size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + (128 * 32); + // size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (192 * 32); + // size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16; + // size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + ( 64 * 
32); + // size_t load_smem_a_2_k_1 = load_smem_a_0_k_1 + (128 * 32); + // size_t load_smem_a_3_k_1 = load_smem_a_0_k_1 + (192 * 32); + + // unswizzled reshaped A - SMEM_A is 128 rows x 64 cols, [ (M=0, K=0), (M=0, K=1), (M=8, K=0), (M=8, K=1) ], etc. + // size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K); + // size_t store_smem_a_off = ((threads % 8) * 8) + ((threads / 8) * 64); // 32 rows / 64 cols per copy + // size_t load_smem_a_0_k_0 = (wg_m * 16 * 64) + ((wg_threads % 16) * 64) + ((wg_threads / 16) * 8); + // size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (64 * 64); + // size_t load_smem_a_2_k_0 = load_smem_a_0_k_0 + + 32; + // size_t load_smem_a_3_k_0 = load_smem_a_0_k_0 + (64 * 64) + 32; + // size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16; + // size_t load_smem_a_1_k_1 = load_smem_a_1_k_0 + 16; + // size_t load_smem_a_2_k_1 = load_smem_a_2_k_0 + 16; + // size_t load_smem_a_3_k_1 = load_smem_a_3_k_0 + 16; + + // swizzled A + size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K); + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; + size_t load_smem_a_phase = (threads / 16) % 2; + size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); + size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); + size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); + size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); + + // unswizzed B + // size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + // size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); // 16 rows / 128 cols per copy + // size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 16) * 128) + ((wg_threads / 16) * 8); + // size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32; + // size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64; + // size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96; + // size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + // size_t load_smem_b_1_k_1 = load_smem_b_0_k_1 + 32; + // size_t load_smem_b_2_k_1 = load_smem_b_0_k_1 + 64; + // size_t load_smem_b_3_k_1 = load_smem_b_0_k_1 + 96; + + // swizzled B + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy + size_t load_smem_b_row = (threads % 16) * 128; + size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16); + size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_b_2_k_0 = load_smem_b_row + 
(((load_smem_b_phase + 8) ^ (threads % 8)) * 8); + size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128); + + // create accs (M=4, N=8) + half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements + half8 a_frag_0_k_0; + half8 a_frag_1_k_0; + half8 a_frag_2_k_0; + half8 a_frag_3_k_0; + half8 a_frag_0_k_1; + half8 a_frag_1_k_1; + half8 a_frag_2_k_1; + half8 a_frag_3_k_1; + + // create register for block B elements + half4 b_frag_0_k_0; + half4 b_frag_1_k_0; + half4 b_frag_2_k_0; + half4 b_frag_3_k_0; + half4 b_frag_4_k_0; + half4 b_frag_5_k_0; + half4 b_frag_6_k_0; + half4 b_frag_7_k_0; + half4 b_frag_0_k_1; + half4 b_frag_1_k_1; + half4 b_frag_2_k_1; + half4 b_frag_3_k_1; + half4 b_frag_4_k_1; + half4 b_frag_5_k_1; + half4 b_frag_6_k_1; + half4 b_frag_7_k_1; + + __syncthreads(); + + // load first tile + // unswizzled 256 x 32 + // __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + // __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16); + // __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16); + // __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16); + // unswizzled 128 x 64 + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + 
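// ---------------------------------------------------------------------------
// Editor's note (illustrative aside, not part of the original patch): a worked
// example of the XOR swizzle used for store_smem_a_off above.  Since
//   ((threads * 8) ^ threads) & 56  ==  8 * ((threads % 8) ^ ((threads / 8) % 8)),
// each thread's 16-byte (8-half) chunk column is XORed with its row (mod 8)
// inside the 128-row x 64-col SMEM_A tile:
//   thread  0: row 0, chunk 0 ^ 0 = 0  -> offset   0
//   thread  8: row 1, chunk 0 ^ 1 = 1  -> offset  64 +  8 =  72
//   thread  9: row 1, chunk 1 ^ 1 = 0  -> offset  64 +  0 =  64
//   thread 17: row 2, chunk 1 ^ 2 = 3  -> offset 128 + 24 = 152
// For a fixed chunk column, any 8 consecutive rows therefore map to 8 distinct
// 16-byte groups, which is what keeps the cp.async stores and the ldmatrix
// loads (which apply the matching XOR when computing their column) free of
// shared-memory bank conflicts.
// ---------------------------------------------------------------------------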
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + global_a_off += 32; + global_b_off += 32 * N; + + // load second tile + // unswizzled 256 x 32 + // __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + // __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16); + // __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16); + // __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16); + // unswizzled 128 x 64 + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + + global_a_off += 32; + global_b_off += 32 * N; + + // wait on first pre-fetch load + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 for the first tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + int phase_k = block_k % 3; + half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2); + + int next_phase_k = (block_k+1) % 3; + half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2); + + int store_phase_k = (block_k+2) % 3; + half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? 
smem_b_1 : smem_b_2); + + // load K=1 elements for the current tile + __ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]); + __ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]); + __ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]); + + // MMA K=0, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7); + acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7); + + // load next tile + if (block_k < (num_k_blocks-2)) { + // unswizzled 256 x 32 + // 
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + // __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*32)], &data1[global_a_off + ( 64*K)], 16); + // __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (128*32)], &data1[global_a_off + (128*K)], 16); + // __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + (192*32)], &data1[global_a_off + (192*K)], 16); + // unswizzled 128 x 64 + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + global_a_off += 32; + global_b_off += 32 * N; + } + __pipeline_commit(); + + // wait next tile + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 for the next tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]); + + // MMA K=1, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2); + acc_frag_2_3 = 
__WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7); + acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7); + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // slower way: write accs one by one to data0 + + size_t wg_c_off = ((grid_m * 256) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16); + size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N); + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z; + data0[wg_c_off + 
thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w; + + wg_c_off += 64*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w; + + wg_c_off += 64*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_2_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_2_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_2_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_2_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_2_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_2_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_2_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_2_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_2_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_2_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_2_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_2_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_2_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_2_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_2_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_2_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_2_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_2_4.y; + 
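// ---------------------------------------------------------------------------
// Editor's note (clarifying comment, not part of the original patch): the
// store pattern throughout this epilogue follows the PTX mma.m16n8k16 FP16
// accumulator layout.  For lane l of a warp, with groupID = l / 4 and
// tid_in_group = l % 4, the four halves of each half4 accumulator sit at
//   .x -> (row groupID,     col 2*tid_in_group    )
//   .y -> (row groupID,     col 2*tid_in_group + 1)
//   .z -> (row groupID + 8, col 2*tid_in_group    )
//   .w -> (row groupID + 8, col 2*tid_in_group + 1)
// which is exactly thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N)
// plus the extra (8 * N) for .z/.w.  The 0,1,4,5,8,9,12,13 (*8) column offsets
// come from the ldmatrix.x4.trans B loads: each load returns two 8-column
// halves (n and n+8) at n-bases wg_n*16 + {0, 32, 64, 96} of the 128-wide tile.
// ---------------------------------------------------------------------------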
data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_2_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_2_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_2_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_2_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_2_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_2_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_2_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_2_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_2_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_2_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_2_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_2_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_2_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_2_7.w; + + wg_c_off += 64*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_3_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_3_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_3_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_3_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_3_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_3_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_3_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_3_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_3_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_3_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_3_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_3_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_3_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_3_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_3_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_3_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_3_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_3_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_3_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_3_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_3_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_3_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_3_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_3_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_3_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_3_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_3_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_3_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_3_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_3_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_3_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_3_7.w; + +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu new file mode 100644 index 0000000000..588f6b3c08 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu @@ -0,0 +1,482 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include 
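// ---------------------------------------------------------------------------
// Editor's sketch (illustrative only, not used by the kernel below): this file
// rotates three SMEM buffers per 32-wide K block -- compute from buffer
// (block_k % 3), consume the already-prefetched buffer ((block_k + 1) % 3)
// next, and issue a new prefetch into buffer ((block_k + 2) % 3).  The tiny
// constexpr helper below just makes that rotation explicit and checks it at
// compile time; `stage` is a hypothetical name introduced here, not something
// the kernel uses.
// ---------------------------------------------------------------------------
__host__ __device__ constexpr int stage(int block_k, int offset) { return (block_k + offset) % 3; }
static_assert(stage(0, 0) == 0 && stage(0, 1) == 1 && stage(0, 2) == 2, "block 0: compute buf0, next buf1, prefetch into buf2");
static_assert(stage(4, 0) == 1 && stage(4, 1) == 2 && stage(4, 2) == 0, "block 4: compute buf1, next buf2, prefetch into buf0");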
+#include +#define SMEM_N_WIDTH 136 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c); + asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };" + : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) { + extern __shared__ char smem[]; + half *smem_a_0 = (half *)(smem); + half *smem_a_1 = (half *)(smem + 16384); + half *smem_a_2 = (half *)(smem + 32768); + half *smem_b_0 = (half *)(smem + 49152); + half *smem_b_1 = (half *)(smem + 57344); + half *smem_b_2 = (half *)(smem + 65536); + + int grid_m = blockIdx.x; /* M//256 */ + int grid_n = blockIdx.y; /* N//128 */ + int wg_threads = threadIdx.x; // 32 + int wg_m = threadIdx.y; // 4 + int wg_n = threadIdx.z; // 2 + int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */ + int num_k_blocks = K / 32; + + // ldmatrix indices - 4x loads of 8x8 matrices by 32 threads + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // swizzled A - SMEM_A is 128 rows x 64 cols + size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K); + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; + size_t load_smem_a_phase = (threads / 16) % 2; + size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); + size_t load_smem_a_1_k_1 = 
load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); + size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); + size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); + + // swizzled B - SMEM_B is 32 rows x 128 cols + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy + size_t load_smem_b_row = (threads % 16) * 128; + size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16); + size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8); + size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128); + + // create accs (M=4, N=8) + half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements + half8 a_frag_0_k_0; + half8 a_frag_1_k_0; + half8 a_frag_2_k_0; + half8 a_frag_3_k_0; + half8 a_frag_0_k_1; + half8 a_frag_1_k_1; + half8 a_frag_2_k_1; + half8 a_frag_3_k_1; + + // create register for block B elements + half4 b_frag_0_k_0; + half4 b_frag_1_k_0; + half4 b_frag_2_k_0; + half4 b_frag_3_k_0; + half4 b_frag_4_k_0; + half4 b_frag_5_k_0; + half4 b_frag_6_k_0; + half4 
b_frag_7_k_0; + half4 b_frag_0_k_1; + half4 b_frag_1_k_1; + half4 b_frag_2_k_1; + half4 b_frag_3_k_1; + half4 b_frag_4_k_1; + half4 b_frag_5_k_1; + half4 b_frag_6_k_1; + half4 b_frag_7_k_1; + + __syncthreads(); + + // load first tile + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + global_a_off += 32; + global_b_off += 32 * N; + + // load second tile + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + + global_a_off += 32; + global_b_off += 32 * N; + + // wait on first pre-fetch load + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 elements for the first tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + int phase_k = block_k % 3; + half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2); + + int next_phase_k = (block_k+1) % 3; + half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2); + + int store_phase_k = (block_k+2) % 3; + half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? 
smem_b_1 : smem_b_2); + + // load K=1 elements for the current tile + __ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]); + __ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]); + __ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]); + + // MMA K=0, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7); + acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7); + + // load next tile if needed + if (block_k < (num_k_blocks-2)) { + 
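// Editor's note (clarification, not in the original): two tiles are already in
// flight from the pre-loop prefetches, so iteration block_k consumes tile
// block_k and prefetches tile block_k + 2 here; the guard simply stops
// prefetching once block_k + 2 would run past num_k_blocks.  __pipeline_commit()
// below is still issued on the guarded-off iterations so that each
// __pipeline_wait_prior(1) keeps matching one commit per loop.  Example: with
// K = 128 (num_k_blocks = 4), iterations 0 and 1 prefetch tiles 2 and 3, and
// iterations 2 and 3 skip the copy but still commit.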
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16); + global_a_off += 32; + global_b_off += 32 * N; + } + __pipeline_commit(); + + // wait next tile + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 elements for the next tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]); + + // MMA K=1, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7); + acc_frag_3_0 = 
__WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7); + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // faster epilogue: write each 8x8 TC accs to SMEM first + // - SMEM_N_WIDTH 8 larger than 128 required to deconflict bank access + // - around 14 micros + // - check bank conflict with in sudo with: "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py" + + // epilogue chunk with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (hi/lo for each in TC M) + // 1) write 32 rows of 128 cols (rows 0-7, 16-23, 32-39, 48-53 in acc_frag_0.lo, then acc_frag_0.hi, etc.) + // 2) read/write 16 rows of 128 elements in 8 elem (16B) chunks + + half2 *smem32_d = (half2 *)(smem); + half8 *smem128_d = (half8 *)(smem); + half8 *out128_d = (half8 *)(data0); + size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2)); + size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4); + size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16); + size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) + + ((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16); + + // write acc_frag_0_* + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w); + 
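// ---------------------------------------------------------------------------
// Editor's note (worked example, not in the original): why SMEM_N_WIDTH = 136
// removes the store-side bank conflicts.  Shared memory has 32 four-byte
// banks, and a 136-half row is 272 B = 68 banks, so row r of this staging
// buffer starts at bank (68 * r) % 32 = (4 * r) % 32.  For the half2 (4 B)
// stores in this epilogue, lane l of a warp writes row l/4 and column pair
// l%4, i.e. bank (4 * (l / 4) + (l % 4) + base) % 32 = (l + base) % 32, so all
// 32 lanes hit 32 different banks.  With an unpadded 128-half row
// (256 B = 64 banks = 0 mod 32) every row would start at the same bank and the
// warp's stores would collide 8-way on only 4 banks.
// ---------------------------------------------------------------------------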
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write acc_frag_1_* + out128_d_off += (64 * (N / 8)); + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write acc_frag_2_* + out128_d_off += (64 * (N / 8)); + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y); + 
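// ---------------------------------------------------------------------------
// Editor's note (clarifying comment, not in the original): each of the eight
// rounds in this epilogue stages a 32-row x 128-col slab in SMEM, with
// staging row = wg_m * 8 + (wg_threads / 4), and then all 256 threads stream
// it back out with each thread moving two 16 B (half8) chunks 32 output rows
// apart.  Within one 64-row acc_frag_m block the .x/.y rounds therefore cover
// output rows 0-7, 16-23, 32-39, 48-55 and the .z/.w rounds cover rows 8-15,
// 24-31, 40-47, 56-63, before out128_d_off advances by 64 rows for the next
// acc_frag_m group.
// ---------------------------------------------------------------------------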
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write acc_frag_3_* + out128_d_off += (64 * (N / 8)); + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = 
half2(acc_frag_3_2.z, acc_frag_3_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + __syncthreads(); +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp16_fp16.no_xor.cu b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.no_xor.cu new file mode 100644 index 0000000000..962a219457 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp16_fp16.no_xor.cu @@ -0,0 +1,486 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define SMEM_N_WIDTH 136 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c); + asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };" + : "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) { + extern __shared__ char smem[]; + half *smem_a_0 = (half *)(smem); + half *smem_a_1 = (half *)(smem + 16384); + half *smem_a_2 = (half *)(smem + 32768); + half *smem_b_0 = (half *)(smem + 49152); + half *smem_b_1 = (half *)(smem + 57344); + half *smem_b_2 = (half *)(smem + 65536); + + int grid_m = blockIdx.x; /* M//256 */ + int grid_n = blockIdx.y; /* N//128 */ + int wg_threads = threadIdx.x; // 32 + int wg_m = threadIdx.y; // 4 + int wg_n = threadIdx.z; // 2 + int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* 256 */ + int 
num_k_blocks = K / 32; + + // ldmatrix indices - 4x loads of 8x8 matrices by 32 threads + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // swizzled A - SMEM_A is 128 rows x 64 cols + size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K); + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; + size_t load_smem_a_phase = (threads / 16) % 2; + size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); + size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); + size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); + size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); + + // swizzled B - SMEM_B is 64 rows x 64 cols + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + size_t store_smem_b_off = // 32 rows of 64 cols per copy + ((threads / 128) * (64)) + // [A,C] vs [B,D] in ldmatrix + ((threads % 2) * (2 * 64)) + // [A vs C] or [B vs. D] + (((threads / 2) % 2) * (4 * 64)) + // WG_N in [0, 1] + (((threads / 4) % 4) * (8 * 64)) + // B in [0, 1, 2, 3] + (((threads / 16) % 8) * (8)); // cols in SMEM_B i.e. 
rows of 8x8 + + size_t load_smem_b_0_k_0 = (wg_n * 4 * 64) + ((wg_threads / 8) * 64) + ((wg_threads % 8) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + ( 8 * 64); + size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + (16 * 64); + size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + (24 * 64); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (32 * 64); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (32 * 64); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (32 * 64); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (32 * 64); + + // create accs (M=4, N=8) + half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f); + half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements + half8 a_frag_0_k_0; + half8 a_frag_1_k_0; + half8 a_frag_2_k_0; + half8 a_frag_3_k_0; + half8 a_frag_0_k_1; + half8 a_frag_1_k_1; + half8 a_frag_2_k_1; + half8 a_frag_3_k_1; + + // create register for block B elements + half4 b_frag_0_k_0; + half4 b_frag_1_k_0; + half4 b_frag_2_k_0; + half4 b_frag_3_k_0; + half4 b_frag_4_k_0; + half4 b_frag_5_k_0; + half4 b_frag_6_k_0; + half4 b_frag_7_k_0; + half4 b_frag_0_k_1; + half4 b_frag_1_k_1; + half4 b_frag_2_k_1; + half4 b_frag_3_k_1; + half4 b_frag_4_k_1; + half4 b_frag_5_k_1; + half4 b_frag_6_k_1; + half4 b_frag_7_k_1; + + __syncthreads(); + + // load first tile + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + 
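+  // each 16B cp.async pass by the 256 threads covers 16 global rows of the 32x128 B slice,
+  // remapped into 32 rows of the 64x64 SMEM_B layout; the second copy below handles global rows 16-31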
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + global_a_off += 32; + global_b_off += 32 * N; + + // load second tile + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16); + __pipeline_commit(); + + global_a_off += 32; + global_b_off += 32 * N; + + // wait on first pre-fetch load + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 elements for the first tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + int phase_k = block_k % 3; + half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2); + + int next_phase_k = (block_k+1) % 3; + half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2); + + int store_phase_k = (block_k+2) % 3; + half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2); + half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? 
smem_b_1 : smem_b_2); + + // load K=1 elements for the current tile + __ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]); + __ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]); + __ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]); + + // MMA K=0, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7); + acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7); + + // load next tile if needed + if (block_k < (num_k_blocks-2)) { + 
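+      // prefetch the tile for iteration block_k+2 into the third (store) buffer;
+      // the last two iterations skip this and consume tiles already resident in SMEM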
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16); + __pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 32*64)], &data2[global_b_off + ( 16*N)], 16); + global_a_off += 32; + global_b_off += 32 * N; + } + __pipeline_commit(); + + // wait next tile + __pipeline_wait_prior(1); + __syncthreads(); + + // load K=0 elements for the next tile + __ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]); + __ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]); + __ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]); + __ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]); + + // MMA K=1, (M=4 x N=8) + acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7); + acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0); + acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1); + acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2); + acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3); + acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4); + acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5); + acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6); + acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7); + acc_frag_3_0 = 
__WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0); + acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1); + acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2); + acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3); + acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4); + acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5); + acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6); + acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7); + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // faster epilogue: write each 8x8 TC accs to SMEM first + // - SMEM_N_WIDTH 8 larger than 128 required to deconflict bank access + // - around 14 micros + // - check bank conflict with in sudo with: "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py" + + // epilogue chunk with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (hi/lo for each in TC M) + // 1) write 32 rows of 128 cols (rows 0-7, 16-23, 32-39, 48-53 in acc_frag_0.lo, then acc_frag_0.hi, etc.) + // 2) read/write 16 rows of 128 elements in 8 elem (16B) chunks + + half2 *smem32_d = (half2 *)(smem); + half8 *smem128_d = (half8 *)(smem); + half8 *out128_d = (half8 *)(data0); + size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2)); + size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4); + size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16); + size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) + + ((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16); + + // write acc_frag_0_* + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w); + 
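+  /* illustrative note (sketch for clarity, not emitted by the kernel): the half2 column slots used
+     in these stores (0,1,4,5,8,9,12,13) skip the pairs claimed by the other WG_N warp, and the
+     extra 8 halves of row padding (SMEM_N_WIDTH = 136 vs 128) keep each warp's 8-row x 4-half2
+     store pattern bank-conflict free, since a half2 is one 4-byte bank word:
+       for (int t = 0; t < 32; t++) {
+         int row = t / 4, col = t % 4;
+         int bank = (row * (136 / 2) + col) % 32;  // 68-word row stride -> rows land 4 banks apart, all 32 banks distinct
+         // with an unpadded 128-wide row the stride is 64 words (bank 0 every row) -> 8-way conflicts
+       }
+  */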
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write acc_frag_1_* + out128_d_off += (64 * (N / 8)); + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write acc_frag_2_* + out128_d_off += (64 * (N / 8)); + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y); + 
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write acc_frag_3_* + out128_d_off += (64 * (N / 8)); + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + // write 32 rows of 128 N elements to SMEM + __syncthreads(); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = 
half2(acc_frag_3_2.z, acc_frag_3_2.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w); + smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w); + + // each thread reads and writes two 8 element chunks + __syncthreads(); + out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off]; + out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))]; + + __syncthreads(); +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp32_fp16.hcopt.cu b/extra/gemm/max_kernels/nv.fp16_fp32_fp16.hcopt.cu new file mode 100644 index 0000000000..f17f8693a1 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp32_fp16.hcopt.cu @@ -0,0 +1,157 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +struct __align__(8) half4 { half x, y, z, w; }; __device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } +__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b); +asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };" + : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); +return c;} +extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2) { + int gidx0 = blockIdx.x; /* 32 */ + int gidx1 = blockIdx.y; /* 64 */ + int lidx0 = threadIdx.x; /* 16 */ + int lidx1 = threadIdx.y; /* 2 */ + int lidx2 = threadIdx.z; /* 4 */ + float4 cast0 = make_float4(0.0f,0.0f,0.0f,0.0f); + int alu0 = (gidx0*128); + int alu1 = (gidx1*262144); + int alu2 = (lidx1*32768); + int alu3 = (lidx2*32); + int alu4 = (lidx0/8); + int alu5 = (alu4*16384); + int alu6 = (lidx0%2); + int alu7 = (alu6*2); + int alu8 = ((lidx0/2)%2); + int alu9 = (alu8*4); + int alu10 = ((lidx0/4)%2); + int alu11 = (alu10*8192); + int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3); + int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2); + float4 acc0 = cast0; + float4 acc1 = cast0; + float4 acc2 = cast0; + float4 acc3 = cast0; + float4 acc4 = cast0; + float4 acc5 = cast0; + float4 acc6 = cast0; + float4 acc7 = cast0; + float4 acc8 = cast0; + float4 acc9 = cast0; + float4 acc10 = cast0; + float4 acc11 = cast0; + float4 acc12 = cast0; + float4 acc13 = cast0; + float4 acc14 = cast0; + float4 acc15 = cast0; + for (int ridx0 = 0; ridx0 < 256; ridx0++) { + int alu14 = (ridx0*16); + int alu15 = (alu13+alu14); + int alu16 = (alu14+alu13); + int alu17 = (alu0+(alu6*8192)+(alu8*16384)+alu10+(alu4*2)+(lidx1*4)+alu3+(ridx0*65536)); + half val0 = data2[alu17+8]; + half val1 = data2[alu17+16]; + half val2 = data2[alu17+24]; + half val3 = data2[alu17+4096]; + half val4 = data2[alu17+4104]; + half val5 = data2[alu17+4112]; + half val6 = data2[alu17+4120]; + half val7 = data2[alu17+32768]; + half 
val8 = data2[alu17+32776]; + half val9 = data2[alu17+32784]; + half val10 = data2[alu17+32792]; + half val11 = data2[alu17+36864]; + half val12 = data2[alu17+36872]; + half4 cast1 = make_half4(val0,val4,val8,val12); + half val13 = data2[alu17+36880]; + half4 cast2 = make_half4(val1,val5,val9,val13); + half val14 = data2[alu17+36888]; + half4 cast3 = make_half4(val2,val6,val10,val14); + half val15 = data2[alu17]; + half4 cast4 = make_half4(val15,val3,val7,val11); + half2 val16 = *((half2*)(data1+alu15+4096)); + half2 val17 = *((half2*)(data1+alu15+65536)); + half2 val18 = *((half2*)(data1+alu15+69632)); + half2 val19 = *((half2*)(data1+alu15+131072)); + half2 val20 = *((half2*)(data1+alu15+135168)); + half2 val21 = *((half2*)(data1+alu15+196608)); + half2 val22 = *((half2*)(data1+alu15+200704)); + half2 val23 = *((half2*)(data1+alu15)); + half2 val24 = *((half2*)(data1+alu16+8)); + half2 val25 = *((half2*)(data1+alu16+4104)); + half8 cast5 = make_half8(val23.x,val23.y,val16.x,val16.y,val24.x,val24.y,val25.x,val25.y); + float4 wmma0 = __WMMA_8_16_16_half_float(cast5, cast1, acc1); + float4 wmma1 = __WMMA_8_16_16_half_float(cast5, cast2, acc2); + float4 wmma2 = __WMMA_8_16_16_half_float(cast5, cast3, acc3); + float4 wmma3 = __WMMA_8_16_16_half_float(cast5, cast4, acc0); + half2 val26 = *((half2*)(data1+alu16+65544)); + half2 val27 = *((half2*)(data1+alu16+69640)); + half8 cast6 = make_half8(val17.x,val17.y,val18.x,val18.y,val26.x,val26.y,val27.x,val27.y); + float4 wmma4 = __WMMA_8_16_16_half_float(cast6, cast1, acc5); + float4 wmma5 = __WMMA_8_16_16_half_float(cast6, cast2, acc6); + float4 wmma6 = __WMMA_8_16_16_half_float(cast6, cast3, acc7); + float4 wmma7 = __WMMA_8_16_16_half_float(cast6, cast4, acc4); + half2 val28 = *((half2*)(data1+alu16+131080)); + half2 val29 = *((half2*)(data1+alu16+135176)); + half8 cast7 = make_half8(val19.x,val19.y,val20.x,val20.y,val28.x,val28.y,val29.x,val29.y); + float4 wmma8 = __WMMA_8_16_16_half_float(cast7, cast1, acc9); + float4 wmma9 = __WMMA_8_16_16_half_float(cast7, cast2, acc10); + float4 wmma10 = __WMMA_8_16_16_half_float(cast7, cast3, acc11); + float4 wmma11 = __WMMA_8_16_16_half_float(cast7, cast4, acc8); + half2 val30 = *((half2*)(data1+alu16+196616)); + half2 val31 = *((half2*)(data1+alu16+200712)); + half8 cast8 = make_half8(val21.x,val21.y,val22.x,val22.y,val30.x,val30.y,val31.x,val31.y); + float4 wmma12 = __WMMA_8_16_16_half_float(cast8, cast1, acc13); + float4 wmma13 = __WMMA_8_16_16_half_float(cast8, cast2, acc14); + float4 wmma14 = __WMMA_8_16_16_half_float(cast8, cast3, acc15); + float4 wmma15 = __WMMA_8_16_16_half_float(cast8, cast4, acc12); + acc0 = wmma3; + acc1 = wmma0; + acc2 = wmma1; + acc3 = wmma2; + acc4 = wmma7; + acc5 = wmma4; + acc6 = wmma5; + acc7 = wmma6; + acc8 = wmma11; + acc9 = wmma8; + acc10 = wmma9; + acc11 = wmma10; + acc12 = wmma15; + acc13 = wmma12; + acc14 = wmma13; + acc15 = wmma14; + } + *((half2*)(data0+alu12+8)) = make_half2((half)(acc1.x),(half)(acc1.y)); + *((half2*)(data0+alu12+16)) = make_half2((half)(acc2.x),(half)(acc2.y)); + *((half2*)(data0+alu12+24)) = make_half2((half)(acc3.x),(half)(acc3.y)); + *((half2*)(data0+alu12+4096)) = make_half2((half)(acc0.z),(half)(acc0.w)); + *((half2*)(data0+alu12+4104)) = make_half2((half)(acc1.z),(half)(acc1.w)); + *((half2*)(data0+alu12+4112)) = make_half2((half)(acc2.z),(half)(acc2.w)); + *((half2*)(data0+alu12+4120)) = make_half2((half)(acc3.z),(half)(acc3.w)); + *((half2*)(data0+alu12+65536)) = make_half2((half)(acc4.x),(half)(acc4.y)); + *((half2*)(data0+alu12+65544)) 
= make_half2((half)(acc5.x),(half)(acc5.y)); + *((half2*)(data0+alu12+65552)) = make_half2((half)(acc6.x),(half)(acc6.y)); + *((half2*)(data0+alu12+65560)) = make_half2((half)(acc7.x),(half)(acc7.y)); + *((half2*)(data0+alu12+69632)) = make_half2((half)(acc4.z),(half)(acc4.w)); + *((half2*)(data0+alu12+69640)) = make_half2((half)(acc5.z),(half)(acc5.w)); + *((half2*)(data0+alu12+69648)) = make_half2((half)(acc6.z),(half)(acc6.w)); + *((half2*)(data0+alu12+69656)) = make_half2((half)(acc7.z),(half)(acc7.w)); + *((half2*)(data0+alu12+131072)) = make_half2((half)(acc8.x),(half)(acc8.y)); + *((half2*)(data0+alu12+131080)) = make_half2((half)(acc9.x),(half)(acc9.y)); + *((half2*)(data0+alu12+131088)) = make_half2((half)(acc10.x),(half)(acc10.y)); + *((half2*)(data0+alu12+131096)) = make_half2((half)(acc11.x),(half)(acc11.y)); + *((half2*)(data0+alu12+135168)) = make_half2((half)(acc8.z),(half)(acc8.w)); + *((half2*)(data0+alu12+135176)) = make_half2((half)(acc9.z),(half)(acc9.w)); + *((half2*)(data0+alu12+135184)) = make_half2((half)(acc10.z),(half)(acc10.w)); + *((half2*)(data0+alu12+135192)) = make_half2((half)(acc11.z),(half)(acc11.w)); + *((half2*)(data0+alu12+196608)) = make_half2((half)(acc12.x),(half)(acc12.y)); + *((half2*)(data0+alu12+196616)) = make_half2((half)(acc13.x),(half)(acc13.y)); + *((half2*)(data0+alu12+196624)) = make_half2((half)(acc14.x),(half)(acc14.y)); + *((half2*)(data0+alu12+196632)) = make_half2((half)(acc15.x),(half)(acc15.y)); + *((half2*)(data0+alu12+200704)) = make_half2((half)(acc12.z),(half)(acc12.w)); + *((half2*)(data0+alu12+200712)) = make_half2((half)(acc13.z),(half)(acc13.w)); + *((half2*)(data0+alu12+200720)) = make_half2((half)(acc14.z),(half)(acc14.w)); + *((half2*)(data0+alu12+200728)) = make_half2((half)(acc15.z),(half)(acc15.w)); + *((half2*)(data0+alu12)) = make_half2((half)(acc0.x),(half)(acc0.y)); +} \ No newline at end of file diff --git a/extra/gemm/max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu new file mode 100644 index 0000000000..689da87d35 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu @@ -0,0 +1,398 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + 
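+  // the .x4.trans load returns four 8x8 tiles: regs 0/1 (lo) and regs 2/3 (hi) are consumed
+  // as the B fragments for two adjacent n=8 mma tiles of the 16x16 block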
addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b); + asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };" + : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) { + int grid_m = blockIdx.x; /* M//64 */ + int grid_n = blockIdx.y; /* N//128 */ + int threads = threadIdx.x; /* 128 */ + int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks + int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton + int wg_threads = threads%32; + int num_k_blocks = K / 64; + + // load indexes + size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K); + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + + // swizzled smem store offsets - columns of smem are swizzled + // here's a link to a description of the triton: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579 + // see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15 + size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19 + + // ldmatrix indices + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // swizzled ldmatrix + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293 + size_t load_smem_a_phase = (threads / 16) % 2; // r4 + + size_t load_smem_b_row = (threads % 16) * 128; // r299 + size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order) + + size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38 + size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64); + size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8); + size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8); + + size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316; + size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128); + + size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319; + size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64); + size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128); + size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128); + size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128); + 
size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128); + + size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322; + size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64); + size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128); + size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128); + size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128); + size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128); + + // create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B2_16384 bytes) + __shared__ alignas(16) char smem[49152]; + + // create accs (16 WMMAs and 4 output elements each) and zero + float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements (2) + half8 a_frag_0; + half8 a_frag_1; + + // create register for block B elements (8) + half4 b_frag_0; + half4 b_frag_1; + half4 b_frag_2; + half4 b_frag_3; + half4 b_frag_4; + half4 b_frag_5; + half4 b_frag_6; + half4 b_frag_7; + + half *smem_a_even = (half *)(smem); + half *smem_a_odd = (half *)(smem + 8192); + half *smem_b_even = (half *)(smem + 16384); + half *smem_b_odd = (half *)(smem + 32768); + + // https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/ + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies + + // start first pre-fetch load A + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start first pre-fetch load B + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], 
&data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + __syncthreads(); + + // start second pre-fetch load A + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start second pre-fetch load B + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + + // wait on needed prefetch value + __pipeline_wait_prior(0); // TODO: this enables fast iterations, but incorrect results with 1 (it shouldn't) + __syncthreads(); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + // BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma + half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd; + half *smem_b_curr = (block_k % 2) ? 
smem_b_even : smem_b_odd; + + // first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8 + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + 
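+    // (K elements 32-47 of this 64-wide tile; same BLOCK_M==2 x BLOCK_N==8 ldmatrix/mma pattern as above)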
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // last 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // prefetch next iteration if needed + __syncthreads(); + if (block_k < (num_k_blocks-2)) { + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + 
( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + + global_a_off += 64; + global_b_off += 64 * N; + } + __pipeline_commit(); + + if (block_k < num_k_blocks-1) { + __pipeline_wait_prior(1); + __syncthreads(); + } + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // slower way: write floats one by one to data0 + size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16); + size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N); + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z; + data0[wg_c_off + thread_c_off + 
(8 * N) + 1 + (12*8)] = acc_frag_0_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w; + wg_c_off += 32*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w; +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu new file mode 100644 index 0000000000..cf9c2ca394 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu @@ -0,0 +1,363 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : 
"l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b); + asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };" + : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) { + int grid_m = blockIdx.x; /* M//64 */ + int grid_n = blockIdx.y; /* N//128 */ + int threads = threadIdx.x; /* 128 */ + int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks + int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton + int wg_threads = threads%32; + int num_k_blocks = K / 64; + + // load indexes + size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K); + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + + // non-swizzled - should work slowly with bank conflicts + size_t store_smem_a_off = ((threads % 8) * 8) + ((threads / 8) * 64); + size_t store_smem_b_off = ((threads % 16) * 8) + ((threads / 16) * 128); + + // ldmatrix indices + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // unswizzled ldmatrix + size_t load_smem_a_0_k_0 = (wg_m * 16 * 64) + ((wg_threads % 8) * 64) + (((wg_threads / 8) % 2) * 64 * 8) + ((wg_threads / 16) * 8); + size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32*64); + size_t load_smem_b_0_k_0 = (wg_n * 16) + ((wg_threads % 8) * 128) + (((wg_threads / 8) % 2) * 128 * 8) + ((wg_threads / 16) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_0_k_0 + 32; + size_t load_smem_b_2_k_0 = load_smem_b_0_k_0 + 64; + size_t load_smem_b_3_k_0 = load_smem_b_0_k_0 + 96; + + size_t load_smem_a_0_k_1 = load_smem_a_0_k_0 + 16; + size_t load_smem_a_1_k_1 = load_smem_a_1_k_0 + 16; + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128); + + size_t load_smem_a_0_k_2 = load_smem_a_0_k_0 + 32; + size_t load_smem_a_1_k_2 = load_smem_a_1_k_0 + 32; + size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128); + size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128); + size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128); + size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128); + + size_t load_smem_a_0_k_3 = load_smem_a_0_k_0 + 48; + size_t load_smem_a_1_k_3 = load_smem_a_1_k_0 + 48; + size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128); + size_t 
load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128); + size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128); + size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128); + + // create shared mem (A 8192 bytes, B 16384 bytes) + __shared__ alignas(16) char smem[24576]; + + // create accs (16 WMMAs and 4 output elements each) and zero + float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements (2) + half8 a_frag_0; + half8 a_frag_1; + + // create register for block B elements (8) + half4 b_frag_0; + half4 b_frag_1; + half4 b_frag_2; + half4 b_frag_3; + half4 b_frag_4; + half4 b_frag_5; + half4 b_frag_6; + half4 b_frag_7; + + half *smem_a = (half *)(smem); + half *smem_b = (half *)(smem + 8192); + + // https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/ + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies + + // start first pre-fetch load A + __pipeline_memcpy_async(&smem_a[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start first pre-fetch load B + __pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + __syncthreads(); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + // wait on needed prefetch value + __pipeline_wait_prior(0); + __syncthreads(); + + // BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma + half *smem_a_curr = smem_a; + half *smem_b_curr = smem_b; + + // first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8 + __ldmatrix_a_elems(&a_frag_0, 
&smem_a_curr[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]); + 
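+    // [editor's note] The sixteen __WMMA_8_16_16_half_float calls in each 16-element
+    // K sub-step of this loop are just the fully unrolled form of a 2x8 register tile
+    // of m16n8k16 mma.sync ops (BLOCK_M==2 A fragments x BLOCK_N==8 B fragments).
+    // With hypothetical arrays a_frag[2], b_frag[8] and acc[2][8] standing in for the
+    // individually named fragments used in this kernel, one sub-step is simply:
+    //   for (int m = 0; m < 2; m++)
+    //     for (int n = 0; n < 8; n++)
+    //       acc[m][n] = __WMMA_8_16_16_half_float(a_frag[m], b_frag[n], acc[m][n]);
+    // (the B-fragment loads and mma calls for this sub-step continue below)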
__ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // last 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // prefetch next iteration if needed + __syncthreads(); + if (block_k < (num_k_blocks-1)) { + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + 
(16*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + + global_a_off += 64; + global_b_off += 64 * N; + } + __pipeline_commit(); + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // slower way: write floats one by one to data0 + size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16); + size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N); + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + 
(13*8)] = acc_frag_0_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w; + wg_c_off += 32*N; + data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x; + data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w; + data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x; + data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w; + data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x; + data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w; + data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x; + data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w; + data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x; + data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w; + data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x; + data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w; + data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x; + data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w; + data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x; + data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y; + data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z; + data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w; +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp32_fp32.max.cu b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.max.cu new file mode 100644 index 0000000000..b21e867738 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.max.cu @@ -0,0 +1,439 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t 
reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b); + asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };" + : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) { + int grid_m = blockIdx.x; /* M//64 */ + int grid_n = blockIdx.y; /* N//128 */ + int threads = threadIdx.x; /* 128 */ + int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks + int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton + int wg_threads = threads%32; + int num_k_blocks = K / 64; + + // load indexes + size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K); + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + + // swizzled smem store offsets - columns of smem are swizzled + // here's a link to a description of the triton: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579 + // see also the thunderkittens impl: https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15 + size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19 + + // ldmatrix indices - 4x loads of 8x8 matrices by 32 threads + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // swizzled ldmatrix + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293 + size_t load_smem_a_phase = (threads / 16) % 2; // r4 + + size_t load_smem_b_row = (threads % 16) * 128; // r299 + size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order) + + size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38 + size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64); + size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8); + size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8); + + size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316; + size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = 
load_smem_b_3_k_0 + (16 * 128); + + size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319; + size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64); + size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128); + size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128); + size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128); + size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128); + + size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322; + size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64); + size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128); + size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128); + size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128); + size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128); + + // create shared mem (A_1 8192 bytes, A_2 8192 bytes, B_1 16384 bytes, B2_16384 bytes) + __shared__ alignas(16) char smem[49152]; + + // create accs (16 WMMAs and 4 output elements each) and zero + float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements (2) + half8 a_frag_0; + half8 a_frag_1; + + // create register for block B elements (8) + half4 b_frag_0; + half4 b_frag_1; + half4 b_frag_2; + half4 b_frag_3; + half4 b_frag_4; + half4 b_frag_5; + half4 b_frag_6; + half4 b_frag_7; + + half *smem_a_even = (half *)(smem); + half *smem_a_odd = (half *)(smem + 8192); + half *smem_b_even = (half *)(smem + 16384); + half *smem_b_odd = (half *)(smem + 32768); + + // https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/ + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies + + // start first pre-fetch load A + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_even[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start first pre-fetch load B + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (24*128)], &data2[global_b_off + 
(24*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_even[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + __syncthreads(); + + // start second pre-fetch load A + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_odd[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start second pre-fetch load B + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_odd[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + + // wait on needed prefetch value + __pipeline_wait_prior(0); // TODO: this enables fast iterations, but incorrect results with 1 (it shouldn't) + __syncthreads(); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + // BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma + half *smem_a_curr = (block_k % 2) ? smem_a_even : smem_a_odd; + half *smem_b_curr = (block_k % 2) ? 
smem_b_even : smem_b_odd; + + // first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8 + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // next 16 K elements + 
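+    // [editor's note] This variant double-buffers shared memory with cp.async: K blocks
+    // alternate between the *_even and *_odd buffers, each iteration issues the copy for
+    // block_k+2 into the buffer it has just finished reading, and the
+    // __pipeline_wait_prior(1) at the bottom of the loop waits only for the older batch
+    // in flight, so the data for block_k+1 is guaranteed ready while the newest copy
+    // still overlaps the math. Note that the ternaries above pick the *_odd buffers when
+    // block_k is even, while the first K block was prefetched into *_even; that ordering
+    // may be related to the wait_prior TODO above. A minimal standalone sketch of the
+    // two-stage pattern, with hypothetical names (buf, src, BLK, num_blocks, compute):
+    //
+    //   __shared__ alignas(16) float buf[2][BLK];
+    //   __pipeline_memcpy_async(&buf[0][threadIdx.x * 4], &src[threadIdx.x * 4], 16);
+    //   __pipeline_commit();                              // stage block 0
+    //   for (int k = 0; k < num_blocks; k++) {
+    //     if (k + 1 < num_blocks) {
+    //       __pipeline_memcpy_async(&buf[(k + 1) % 2][threadIdx.x * 4],
+    //                               &src[(k + 1) * BLK + threadIdx.x * 4], 16);
+    //       __pipeline_commit();                          // stage block k+1 into the other buffer
+    //       __pipeline_wait_prior(1);                     // block k has landed; k+1 may still be in flight
+    //     } else {
+    //       __pipeline_wait_prior(0);                     // last block: drain everything
+    //     }
+    //     __syncthreads();
+    //     compute(buf[k % 2]);                            // math on block k overlaps the in-flight copy
+    //     __syncthreads();                                // finish reading before the buffer is overwritten
+    //   }
+    //
+    // (the loads and mma calls for that sub-step follow)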
__ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // last 16 K elements + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_1, acc_frag_1_1); + acc_frag_1_2 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_2, acc_frag_1_2); + acc_frag_1_3 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_3, acc_frag_1_3); + acc_frag_1_4 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_4, acc_frag_1_4); + acc_frag_1_5 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_5, acc_frag_1_5); + acc_frag_1_6 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_6, acc_frag_1_6); + acc_frag_1_7 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_7, acc_frag_1_7); + + // prefetch next iteration if needed + __syncthreads(); + if (block_k < (num_k_blocks-2)) { + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + 
( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + + global_a_off += 64; + global_b_off += 64 * N; + } + __pipeline_commit(); + + if (block_k < num_k_blocks-1) { + __pipeline_wait_prior(1); + __syncthreads(); + } + } + + // write accumulators to output + __pipeline_wait_prior(0); + __syncthreads(); + + // store registers to smem first, then read back to do float4 writes to global + float *smem_d = (float *)(smem); + size_t smem_d_off = (wg_m * 16 * N_PAD) + (wg_n * 16) + ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N_PAD); + smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_0_0.x; + smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_0_0.y; + smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.z; + smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_0_0.w; + smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_0_1.x; + smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_0_1.y; + smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.z; + smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_0_1.w; + smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_0_2.x; + smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_0_2.y; + smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.z; + smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_0_2.w; + smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_0_3.x; + smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_0_3.y; + smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.z; + smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_0_3.w; + smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_0_4.x; + smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_0_4.y; + smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.z; + smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_0_4.w; + smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_0_5.x; + smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_0_5.y; + smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.z; + smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_0_5.w; + smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_0_6.x; + smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_0_6.y; + smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_0_6.z; + smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_0_6.w; + smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_0_7.x; + smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_0_7.y; + smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_0_7.z; + smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_0_7.w; + + 
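+  // [editor's note] Unlike the "slower way" epilogue in the other kernels, this variant
+  // first parks each thread's scattered mma results in shared memory with rows padded to
+  // N_PAD (132 floats instead of 128, presumably so consecutive rows do not start in the
+  // same banks and the scattered writes avoid most bank conflicts), then re-reads them
+  // below in a layout where each thread owns one contiguous float4 per row and can issue
+  // coalesced 16-byte stores to data0. A minimal sketch of the idea, with hypothetical
+  // names (stage, out, row, col, row_base):
+  //
+  //   __shared__ alignas(16) float stage[16][N_PAD];               // 132 > 128 breaks the row/bank alignment
+  //   stage[row][col + 0] = acc.x;  stage[row][col + 1] = acc.y;   // scattered per-thread writes
+  //   __syncthreads();
+  //   float4 v = *(float4 *)&stage[threadIdx.x / 32][(threadIdx.x % 32) * 4];
+  //   *(float4 *)&out[row_base * N + (threadIdx.x % 32) * 4] = v;  // one coalesced 16B store per thread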
__syncthreads(); + size_t load_smem_d_off = ((threads % 32) * 4) + ((threads / 32) * N_PAD); + float4 d_0_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD))); + float4 d_0_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD))); + float4 d_0_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD))); + float4 d_0_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD))); + float4 d_0_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD))); + float4 d_0_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD))); + float4 d_0_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD))); + float4 d_0_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD))); + + __syncthreads(); + smem_d[smem_d_off + 0 + ( 0*8) ] = acc_frag_1_0.x; + smem_d[smem_d_off + 1 + ( 0*8) ] = acc_frag_1_0.y; + smem_d[smem_d_off + 0 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.z; + smem_d[smem_d_off + 1 + ( 0*8) + (8*N_PAD)] = acc_frag_1_0.w; + smem_d[smem_d_off + 0 + ( 1*8) ] = acc_frag_1_1.x; + smem_d[smem_d_off + 1 + ( 1*8) ] = acc_frag_1_1.y; + smem_d[smem_d_off + 0 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.z; + smem_d[smem_d_off + 1 + ( 1*8) + (8*N_PAD)] = acc_frag_1_1.w; + smem_d[smem_d_off + 0 + ( 4*8) ] = acc_frag_1_2.x; + smem_d[smem_d_off + 1 + ( 4*8) ] = acc_frag_1_2.y; + smem_d[smem_d_off + 0 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.z; + smem_d[smem_d_off + 1 + ( 4*8) + (8*N_PAD)] = acc_frag_1_2.w; + smem_d[smem_d_off + 0 + ( 5*8) ] = acc_frag_1_3.x; + smem_d[smem_d_off + 1 + ( 5*8) ] = acc_frag_1_3.y; + smem_d[smem_d_off + 0 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.z; + smem_d[smem_d_off + 1 + ( 5*8) + (8*N_PAD)] = acc_frag_1_3.w; + smem_d[smem_d_off + 0 + ( 8*8) ] = acc_frag_1_4.x; + smem_d[smem_d_off + 1 + ( 8*8) ] = acc_frag_1_4.y; + smem_d[smem_d_off + 0 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.z; + smem_d[smem_d_off + 1 + ( 8*8) + (8*N_PAD)] = acc_frag_1_4.w; + smem_d[smem_d_off + 0 + ( 9*8) ] = acc_frag_1_5.x; + smem_d[smem_d_off + 1 + ( 9*8) ] = acc_frag_1_5.y; + smem_d[smem_d_off + 0 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.z; + smem_d[smem_d_off + 1 + ( 9*8) + (8*N_PAD)] = acc_frag_1_5.w; + smem_d[smem_d_off + 0 + (12*8) ] = acc_frag_1_6.x; + smem_d[smem_d_off + 1 + (12*8) ] = acc_frag_1_6.y; + smem_d[smem_d_off + 0 + (12*8) + (8*N_PAD)] = acc_frag_1_6.z; + smem_d[smem_d_off + 1 + (12*8) + (8*N_PAD)] = acc_frag_1_6.w; + smem_d[smem_d_off + 0 + (13*8) ] = acc_frag_1_7.x; + smem_d[smem_d_off + 1 + (13*8) ] = acc_frag_1_7.y; + smem_d[smem_d_off + 0 + (13*8) + (8*N_PAD)] = acc_frag_1_7.z; + smem_d[smem_d_off + 1 + (13*8) + (8*N_PAD)] = acc_frag_1_7.w; + + __syncthreads(); + float4 d_1_0 = *((float4 *)(smem_d + load_smem_d_off + ( 0 * N_PAD))); + float4 d_1_1 = *((float4 *)(smem_d + load_smem_d_off + ( 4 * N_PAD))); + float4 d_1_2 = *((float4 *)(smem_d + load_smem_d_off + ( 8 * N_PAD))); + float4 d_1_3 = *((float4 *)(smem_d + load_smem_d_off + (12 * N_PAD))); + float4 d_1_4 = *((float4 *)(smem_d + load_smem_d_off + (16 * N_PAD))); + float4 d_1_5 = *((float4 *)(smem_d + load_smem_d_off + (20 * N_PAD))); + float4 d_1_6 = *((float4 *)(smem_d + load_smem_d_off + (24 * N_PAD))); + float4 d_1_7 = *((float4 *)(smem_d + load_smem_d_off + (28 * N_PAD))); + + __syncthreads(); + float *global_d = &data0[((grid_m * 64) * N) + (grid_n * 128) + ((threads % 32) * 4) + ((threads / 32) * N)]; + *((float4 *)(global_d + 0*N)) = d_0_0; + *((float4 *)(global_d + 4*N)) = d_0_1; + *((float4 *)(global_d + 8*N)) = d_0_2; + *((float4 *)(global_d + 12*N)) = d_0_3; + *((float4 *)(global_d + 16*N)) = d_0_4; + *((float4 *)(global_d + 20*N)) = 
d_0_5; + *((float4 *)(global_d + 24*N)) = d_0_6; + *((float4 *)(global_d + 28*N)) = d_0_7; + *((float4 *)(global_d + 32*N)) = d_1_0; + *((float4 *)(global_d + 36*N)) = d_1_1; + *((float4 *)(global_d + 40*N)) = d_1_2; + *((float4 *)(global_d + 44*N)) = d_1_3; + *((float4 *)(global_d + 48*N)) = d_1_4; + *((float4 *)(global_d + 52*N)) = d_1_5; + *((float4 *)(global_d + 56*N)) = d_1_6; + *((float4 *)(global_d + 60*N)) = d_1_7; +} diff --git a/extra/gemm/max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu new file mode 100644 index 0000000000..8e168b8930 --- /dev/null +++ b/extra/gemm/max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu @@ -0,0 +1,371 @@ +#define INFINITY (__int_as_float(0x7f800000)) +#define NAN (__int_as_float(0x7fffffff)) +#include +#include +#define N_PAD 132 + +struct __align__(8) half4 { half x, y, z, w; }; +__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; } + +struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; +__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; } + +__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr = reinterpret_cast(regs); + addr[0] = reg0; + addr[1] = reg1; + addr[2] = reg2; + addr[3] = reg3; +} + +__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) { + uint32_t reg0, reg1, reg2, reg3; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3) + : "l"(__cvta_generic_to_shared(smem)) + ); + uint32_t *addr_lo = reinterpret_cast(regs_lo); + uint32_t *addr_hi = reinterpret_cast(regs_hi); + addr_lo[0] = reg0; + addr_lo[1] = reg1; + addr_hi[0] = reg2; + addr_hi[1] = reg3; +} + +__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { + int *a_pk = (int *) (&a), *b_pk = (int *) (&b); + asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };" + : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) ); + return c; +} + +extern "C" __global__ void __launch_bounds__(128) wmma_example(float* data0, const half* data1, const half* data2, int N, int K) { + int grid_m = blockIdx.x; /* M//64 */ + int grid_n = blockIdx.y; /* N//128 */ + int threads = threadIdx.x; /* 128 */ + int wg_m = (threads/64); // 0 or 1 for 1st and 3rd blocks of b_m=16xb_k=16 vs 2nd and 4th blocks + int wg_n = (threads/32)%2; // 0 or 1 for 1st, 3rd, 5th, 7th blocks of b_n=16xb_k=16 vs 2nd, 4th, 6th, 8th blocks - differs from triton + int wg_threads = threads%32; + int num_k_blocks = K / 64; + + // load indexes + size_t global_a_off = ((grid_m * 64) * K) + ((threads % 8) * 8) + ((threads / 8) * K); + size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N); + + // swizzled smem store offsets - columns of smem are swizzled + // here's a link to a description of the triton: https://github.com/triton-lang/triton/discussions/2026#discussioncomment-6746579 + // see also the thunderkittens impl: 
https://github.com/HazyResearch/ThunderKittens/blob/main/include/types/shared/st.cuh + size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // r15 + size_t store_smem_b_off = ((threads / 16) * 128) + (((threads / 16) * 8) ^ ((threads % 16) * 8)); // r19 + + // ldmatrix indices + // threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D + // [ A | C ] + // [ - + - ] + // [ B | D ] + + // swizzled ldmatrix + size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64; // r293 + size_t load_smem_a_phase = (threads / 16) % 2; // r4 + + size_t load_smem_b_row = (threads % 16) * 128; // r299 + size_t load_smem_b_phase = (wg_n * 2) + (((threads / 16) % 2)); // r297 -- this differs from the generated triton kernel (swapped order) + + size_t load_smem_a_0_k_0 = load_smem_a_row + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8); // r38 + size_t load_smem_a_1_k_0 = load_smem_a_0_k_0 + (32 * 64); + size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8); + size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8); + size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8); + size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8); + + size_t load_smem_a_0_k_1 = load_smem_a_row + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8); // r58 = r293 + r316; + size_t load_smem_a_1_k_1 = load_smem_a_0_k_1 + (32 * 64); + size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128); + size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128); + size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128); + size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128); + + size_t load_smem_a_0_k_2 = load_smem_a_row + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8); // r59 = r293 + r319; + size_t load_smem_a_1_k_2 = load_smem_a_0_k_2 + (32 * 64); + size_t load_smem_b_0_k_2 = load_smem_b_0_k_0 + (32 * 128); + size_t load_smem_b_1_k_2 = load_smem_b_1_k_0 + (32 * 128); + size_t load_smem_b_2_k_2 = load_smem_b_2_k_0 + (32 * 128); + size_t load_smem_b_3_k_2 = load_smem_b_3_k_0 + (32 * 128); + + size_t load_smem_a_0_k_3 = load_smem_a_row + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8); // r60 = r293 + r322; + size_t load_smem_a_1_k_3 = load_smem_a_0_k_3 + (32 * 64); + size_t load_smem_b_0_k_3 = load_smem_b_0_k_0 + (48 * 128); + size_t load_smem_b_1_k_3 = load_smem_b_1_k_0 + (48 * 128); + size_t load_smem_b_2_k_3 = load_smem_b_2_k_0 + (48 * 128); + size_t load_smem_b_3_k_3 = load_smem_b_3_k_0 + (48 * 128); + + // create shared mem (A 8192 bytes, B 16384 bytes) + __shared__ alignas(16) char smem[24576]; + + // create accs (16 WMMAs and 4 output elements each) and zero + float4 acc_frag_0_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_5 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_0_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_0 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_1 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_2 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_3 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_4 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_5 = 
make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_6 = make_float4(0.0f,0.0f,0.0f,0.0f); + float4 acc_frag_1_7 = make_float4(0.0f,0.0f,0.0f,0.0f); + + // create registers for block A elements (2) + half8 a_frag_0; + half8 a_frag_1; + + // create register for block B elements (8) + half4 b_frag_0; + half4 b_frag_1; + half4 b_frag_2; + half4 b_frag_3; + half4 b_frag_4; + half4 b_frag_5; + half4 b_frag_6; + half4 b_frag_7; + + half *smem_a = (half *)(smem); + half *smem_b = (half *)(smem + 8192); + + // https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/ + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-data-copies + + // start first pre-fetch load A + __pipeline_memcpy_async(&smem_a[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_a[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16); + __pipeline_memcpy_async(&smem_a[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16); + __pipeline_memcpy_async(&smem_a[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16); + + // start first pre-fetch load B + __pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16); + __pipeline_memcpy_async(&smem_b[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16); + __pipeline_commit(); + + global_a_off += 64; + global_b_off += 64 * N; + __syncthreads(); + + for (int block_k = 0; block_k < num_k_blocks; block_k++) { + // wait on needed prefetch value + __pipeline_wait_prior(0); + __syncthreads(); + + // BLOCK_K==4: unroll 4 iterations of ldmatrix/wmma + half *smem_a_curr = smem_a; + half *smem_b_curr = smem_b; + + // first load 16 K elements and 16 WMMAs: BLOCK_M==2 * BLOCK_N==8 + __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_0]); + __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_0]); + __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_0]); + __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_0]); + __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_0]); + __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_0]); + acc_frag_0_0 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_0, acc_frag_0_0); + acc_frag_0_1 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_1, acc_frag_0_1); + acc_frag_0_2 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_2, acc_frag_0_2); + acc_frag_0_3 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_3, acc_frag_0_3); + acc_frag_0_4 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_4, acc_frag_0_4); + acc_frag_0_5 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_5, acc_frag_0_5); + acc_frag_0_6 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_6, acc_frag_0_6); + acc_frag_0_7 = __WMMA_8_16_16_half_float(a_frag_0, b_frag_7, acc_frag_0_7); + acc_frag_1_0 = __WMMA_8_16_16_half_float(a_frag_1, b_frag_0, acc_frag_1_0); + acc_frag_1_1 = 
+
+    // next 16 K elements
+    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_1]);
+    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_1]);
+    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_1]);
+    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_1]);
+    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_1]);
+    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_1]);
+    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
+    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
+    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
+    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
+    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
+    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
+    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
+    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
+    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
+    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
+    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
+    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
+    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
+    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
+    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
+    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
+
+    // next 16 K elements
+    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_2]);
+    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_2]);
+    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_2]);
+    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_2]);
+    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_2]);
+    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_2]);
+    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
+    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
+    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
+    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
+    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
+    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
+    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
+    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
+    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
+    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
+    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
+    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
+    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
+    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
+    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
+    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
+
+    // last 16 K elements
+    __ldmatrix_a_elems(&a_frag_0, &smem_a_curr[load_smem_a_0_k_3]);
+    __ldmatrix_a_elems(&a_frag_1, &smem_a_curr[load_smem_a_1_k_3]);
+    __ldmatrix_b_elems(&b_frag_0, &b_frag_1, &smem_b_curr[load_smem_b_0_k_3]);
+    __ldmatrix_b_elems(&b_frag_2, &b_frag_3, &smem_b_curr[load_smem_b_1_k_3]);
+    __ldmatrix_b_elems(&b_frag_4, &b_frag_5, &smem_b_curr[load_smem_b_2_k_3]);
+    __ldmatrix_b_elems(&b_frag_6, &b_frag_7, &smem_b_curr[load_smem_b_3_k_3]);
+    acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_0, acc_frag_0_0);
+    acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_1, acc_frag_0_1);
+    acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_2, acc_frag_0_2);
+    acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_3, acc_frag_0_3);
+    acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_4, acc_frag_0_4);
+    acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_5, acc_frag_0_5);
+    acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_6, acc_frag_0_6);
+    acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0, b_frag_7, acc_frag_0_7);
+    acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_0, acc_frag_1_0);
+    acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_1, acc_frag_1_1);
+    acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_2, acc_frag_1_2);
+    acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_3, acc_frag_1_3);
+    acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_4, acc_frag_1_4);
+    acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_5, acc_frag_1_5);
+    acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_6, acc_frag_1_6);
+    acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1, b_frag_7, acc_frag_1_7);
+
+    // prefetch next iteration if needed
+    __syncthreads();
+    if (block_k < (num_k_blocks-1)) {
+      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
+      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (16*64)], &data1[global_a_off + (16*K)], 16);
+      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (32*64)], &data1[global_a_off + (32*K)], 16);
+      __pipeline_memcpy_async(&smem_a_curr[store_smem_a_off + (48*64)], &data1[global_a_off + (48*K)], 16);
+
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + ( 8*128)], &data2[global_b_off + ( 8*N)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (16*128)], &data2[global_b_off + (16*N)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (24*128)], &data2[global_b_off + (24*N)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (32*128)], &data2[global_b_off + (32*N)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (40*128)], &data2[global_b_off + (40*N)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (48*128)], &data2[global_b_off + (48*N)], 16);
+      __pipeline_memcpy_async(&smem_b_curr[store_smem_b_off + (56*128)], &data2[global_b_off + (56*N)], 16);
+
+      global_a_off += 64;
+      global_b_off += 64 * N;
+    }
+    __pipeline_commit();
+  }
+
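+  // NOTE: smem holds a single 24 KiB stage (8 KiB A + 16 KiB B); the __syncthreads() ahead of the
+  // prefetch inside the loop is what keeps the next tile's cp.async stores from overwriting data
+  // that this iteration's ldmatrix reads still need
+
+  // epilogue mapping: acc_frag_<m>_<n> holds one m16n8 tile; m selects the 16-row block at +m*32
+  // rows from this warp's base row (wg_m*16), n selects an 8-column tile (the wg_n warps interleave
+  // in 16-column blocks, hence the 0,1,4,5,8,9,12,13 column pattern below); within a tile a lane
+  // owns rows (lane/4) and (lane/4)+8 at columns (lane%4)*2 and (lane%4)*2+1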
+  // write accumulators to output
+  __pipeline_wait_prior(0);
+  __syncthreads();
+
+  // slower way: write accumulator elements one by one to data0
+  size_t wg_c_off = ((grid_m * 64) * N) + (grid_n * 128) + (wg_m * 16 * N) + (wg_n * 16);
+  size_t thread_c_off = ((wg_threads % 4) * 2) + (((wg_threads / 4) % 8) * N);
+  data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_0_0.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_0_0.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_0_0.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_0_0.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_0_1.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_0_1.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_0_1.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_0_1.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_0_2.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_0_2.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_0_2.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_0_2.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_0_3.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_0_3.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_0_3.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_0_3.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_0_4.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_0_4.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_0_4.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_0_4.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_0_5.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_0_5.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_0_5.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_0_5.w;
+  data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_0_6.x;
+  data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_0_6.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_0_6.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_0_6.w;
+  data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_0_7.x;
+  data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_0_7.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_0_7.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_0_7.w;
+  wg_c_off += 32*N;
+  data0[wg_c_off + thread_c_off + 0 + ( 0*8)] = acc_frag_1_0.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 0*8)] = acc_frag_1_0.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 0*8)] = acc_frag_1_0.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 0*8)] = acc_frag_1_0.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 1*8)] = acc_frag_1_1.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 1*8)] = acc_frag_1_1.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 1*8)] = acc_frag_1_1.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 1*8)] = acc_frag_1_1.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 4*8)] = acc_frag_1_2.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 4*8)] = acc_frag_1_2.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 4*8)] = acc_frag_1_2.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 4*8)] = acc_frag_1_2.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 5*8)] = acc_frag_1_3.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 5*8)] = acc_frag_1_3.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 5*8)] = acc_frag_1_3.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 5*8)] = acc_frag_1_3.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 8*8)] = acc_frag_1_4.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 8*8)] = acc_frag_1_4.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 8*8)] = acc_frag_1_4.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 8*8)] = acc_frag_1_4.w;
+  data0[wg_c_off + thread_c_off + 0 + ( 9*8)] = acc_frag_1_5.x;
+  data0[wg_c_off + thread_c_off + 1 + ( 9*8)] = acc_frag_1_5.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + ( 9*8)] = acc_frag_1_5.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + ( 9*8)] = acc_frag_1_5.w;
+  data0[wg_c_off + thread_c_off + 0 + (12*8)] = acc_frag_1_6.x;
+  data0[wg_c_off + thread_c_off + 1 + (12*8)] = acc_frag_1_6.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (12*8)] = acc_frag_1_6.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (12*8)] = acc_frag_1_6.w;
+  data0[wg_c_off + thread_c_off + 0 + (13*8)] = acc_frag_1_7.x;
+  data0[wg_c_off + thread_c_off + 1 + (13*8)] = acc_frag_1_7.y;
+  data0[wg_c_off + thread_c_off + (8 * N) + 0 + (13*8)] = acc_frag_1_7.z;
+  data0[wg_c_off + thread_c_off + (8 * N) + 1 + (13*8)] = acc_frag_1_7.w;
+}
diff --git a/extra/gemm/max_matmul.py b/extra/gemm/max_matmul.py
new file mode 100644
index 0000000000..8b6195d7c3
--- /dev/null
+++ b/extra/gemm/max_matmul.py
@@ -0,0 +1,232 @@
+import numpy as np, os
+from tinygrad.helpers import getenv, flat_mv
+from tinygrad import dtypes
+from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Self
+
+# for copied uops
+from tinygrad.codegen.kernel import Kernel, KernelOptError
+from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps, KernelInfo
+from tinygrad.engine.search import Opt, OptOps
+from tinygrad import Device, dtypes, Tensor
+from tinygrad.dtype import PtrDType, DType, DTYPES_DICT
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+
+# problem variations
+DTYPE_IN = DTYPES_DICT[getenv("DTYPE_IN", "half")]
+DTYPE_OUT = DTYPES_DICT[getenv("DTYPE_OUT", "half")]
+DTYPE_ACC = DTYPES_DICT[getenv("DTYPE_ACC", "float")]
+N = getenv("N", 4096)
+M = getenv("M", N)
+K = getenv("K", N)
+CNT = getenv("CNT", 10)
+ATOL = getenv("ATOL", 5e-3 if DTYPE_IN == dtypes.float else 1e-2)
+RTOL = getenv("RTOL", 1e-4 if DTYPE_IN == dtypes.float else 1e-3)
+FLOPS = M * N * K * 2
+BW = 2 * ((M*K) + (K*N) + (M*N))
+
+# algorithm variations
+INPUT = getenv("INPUT", "RAND")
+GEMM_VARIATION = getenv("GEMM_VARIATION", "hcopt")
+
+def randoms():
+  if INPUT == "RAND":
+    na = np.random.default_rng().normal(scale=1.0, size=(M,K)).astype(dtype=np.float32)
+    nb = np.random.default_rng().normal(scale=1.0, size=(K,N)).astype(dtype=np.float32)
+  elif INPUT == "IDENTITY" and M==N==K:
+    na = np.identity(K, dtype=np.float32)
+    nb = np.identity(K, dtype=np.float32)
+  elif INPUT == "OUTPUTONES" and M==K:
+    na = np.identity(K, dtype=np.float32)
+    nb = np.ones((K,N), dtype=np.float32)
+  else:
+    na = np.ones((M,K), dtype=np.float32)
+    nb = np.ones((K,N), dtype=np.float32)
+  nc = np.zeros(M*N, np.float32)
+  if DTYPE_IN != dtypes.float:
+    na = na.astype(np.bfloat16 if DTYPE_IN == dtypes.bfloat16 else np.float16)
+    nb = nb.astype(np.bfloat16 if DTYPE_IN == dtypes.bfloat16 else np.float16)
+  if DTYPE_OUT != dtypes.float:
+    nc = nc.astype(np.bfloat16 if DTYPE_OUT == dtypes.bfloat16 else np.float16)
+  return na, nb, nc
+
+def ast_to_cuda_prog(compiler, ast, opts):
+  k = Kernel(ast)
+  k.required_optimizations()
+  for opt in opts:
+    k.apply_opt(opt)
+  p = k.to_program()
+  return CUDAProgram(device, p.function_name, compiler.compile(p.src))
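+
+# example invocations (GEMM_VARIATION and the DTYPE_* env vars have to match one of the branches below), e.g.:
+#   CUDA=1 GEMM_VARIATION=hcopt python3 extra/gemm/max_matmul.py
+#   CUDA=1 DTYPE_OUT=float GEMM_VARIATION=max python3 extra/gemm/max_matmul.py              # fp16 in, fp32 acc/out
+#   CUDA=1 DTYPE_ACC=half GEMM_VARIATION=3_stage_swizzled python3 extra/gemm/max_matmul.py  # fp16 in/acc/out
+# set DEBUG_VALUES=1 to dump the mismatching elements if the accuracy check fails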
+
+if __name__ == "__main__":
+  print(f"gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}")
+  prog, global_size, local_size = None, None, None
+
+  if getenv("CUDA") == 1:
+    from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
+    device = CUDADevice("cuda:0")
+    compiler = CUDACompiler(device.arch)
+    cudaalloc = CUDAAllocator(device)
+
+    a = cudaalloc.alloc(M*K*DTYPE_IN.itemsize)
+    b = cudaalloc.alloc(K*N*DTYPE_IN.itemsize)
+    c = cudaalloc.alloc(M*N*DTYPE_OUT.itemsize)
+
+    if GEMM_VARIATION == "max" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
+      print("Using CUDA and max kernel (port of the Triton-generated kernel)")
+      # See nv_triton_gemm.annotated.ptx for PTX code which was generated from `PYTHONPATH=. DEBUG=6 CUDA=1 PTX=1 python3 extra/gemm/triton_nv_matmul.py`
+      # this kernel with M=N=K=4096 does 162TFLOPS, vs torch at 144TFLOPS and BEAM=8 tinygrad at 138TFLOPS. theoretical max is 165TFLOPS.
+
+      # WMMA element size is (M, N, K) = (16, 8, 16)
+      # warpgroup size in WMMA tiles is (B_M, B_N, B_K) = (2, 8, 4) so 64 HMMA calls per warpgroup reduce iteration
+      # thread block size is (T_M, T_N, T_K) = (2, 2, 1), i.e. macro blocks in M and N, so 256 HMMA calls per kernel reduce iteration
+      # kernel reduce iteration size in elements = (64, 128, 64)
+      # single iteration SMEM_A = (64 * 64) * (2 bytes / half) = 8192 bytes, SMEM_B = (128 * 64) * (2 bytes / half) = 16384 bytes
+      # double-buffer smem = (8192 + 16384) * 2 = 49152 bytes
+      # reduce for_loop size = [1, 1, (4096 // 16 // 4)==64]
+      # NOTE: T_K > 0 would be group_for_reduce
+      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.max.cu')).read()))
+      args = (c, a, b)
+      kwargs = {
+        'global_size': [M//64, N//128, 1],
+        'local_size': [128, 1, 1],  # 4 warpgroups == (T_M:=2) * (T_N:=2)
+        'wait': True,
+        'vals': (N, K),
+      }
+    elif GEMM_VARIATION == "2_stage_swizzled_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
+      print("Using CUDA, 2-stage reduce pipeline, swizzled SMEM inputs")
+      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.2_stage_swizzled_smem_input.cu')).read()))
+      args = (c, a, b)
+      kwargs = {
+        'global_size': [M//64, N//128, 1],
+        'local_size': [128, 1, 1],  # 4 warpgroups == (T_M:=2) * (T_N:=2)
+        'wait': True,
+        'vals': (N, K),
+      }
+    elif GEMM_VARIATION == "swizzled_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
+      print("Using CUDA, swizzled SMEM inputs")
+      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.swizzled_smem_input.cu')).read()))
+      args = (c, a, b)
+      kwargs = {
+        'global_size': [M//64, N//128, 1],
+        'local_size': [128, 1, 1],  # 4 warpgroups == (T_M:=2) * (T_N:=2)
+        'wait': True,
+        'vals': (N, K),
+      }
+    elif GEMM_VARIATION == "flat_smem_input" and (M%64)==0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.float and DTYPE_ACC == dtypes.float:
+      print("Using CUDA, flat SMEM inputs")
+      prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu')).read()))
"wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp32.flat_smem_input.cu')).read())) + args = (c, a, b) + kwargs = { + 'global_size': [M//64, N//128, 1], + 'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2) + 'wait': True, + 'vals': (N, K), + } + elif GEMM_VARIATION == "hcopt" and M == N == K == 4096 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.float: + print("Using CUDA and generated hcopt") + # [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4)] + prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp32_fp16.hcopt.cu')).read())) + args = (c, a, b) + kwargs = { + 'global_size': [32, 64, 1], + 'local_size': [16, 2, 4], # 16,2 are warp, 4 workgroups upcasted to axis=1 + 'wait': True, + } + elif GEMM_VARIATION == "2_stage" and (M%64)== 0 and (N%128)==0 and (K%64)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half: + print("Using CUDA and un-optimized 2-stage, swizzled SMEM inputs and direct acc to output kernel") + prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.2_stage.cu')).read())) + args = (c, a, b) + kwargs = { + 'global_size': [M//64, N//128, 1], + 'local_size': [128, 1, 1], # 4 warpgroups == (T_M:=2) * (T_N:=2) + 'wait': True, + 'vals': (N, K), + } + elif GEMM_VARIATION == "3_stage" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half: + print("Using CUDA and 3-stage (interleave global copies and ldmatrix)") + prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.3_stage.cu')).read()), 73728) + args = (c, a, b) + kwargs = { + 'global_size': [M//256, N//128, 1], + 'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2 + 'wait': True, + 'vals': (N, K), + } + elif GEMM_VARIATION == "3_stage_swizzled" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half: + print("Using CUDA and 3-stage (interleave global copies and ldmatrix) and swizzled SMEM inputs") + prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu')).read()), 73728) + args = (c, a, b) + kwargs = { + 'global_size': [M//256, N//128, 1], + 'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2 + 'wait': True, + 'vals': (N, K), + } + elif GEMM_VARIATION == "max" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half: + print("Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and epilogue") + prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.max.cu')).read()), 73728) + args = (c, a, b) + kwargs = { + 'global_size': [M//256, N//128, 1], + 'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2 + 'wait': True, + 'vals': (N, K), + } + elif GEMM_VARIATION == "no_xor" and (M%256)== 0 and (N%128)==0 and (K%32)==0 and DTYPE_IN == dtypes.half and DTYPE_OUT == dtypes.half and DTYPE_ACC == dtypes.half: + print("Using CUDA and 3-stage (interleave global copies and ldmatrix), swizzled SMEM inputs and 
epilogue") + prog = CUDAProgram(device, "wmma_example", compiler.compile(open(os.path.join(script_dir, 'max_kernels/nv.fp16_fp16_fp16.no_xor.cu')).read()), 73728) + args = (c, a, b) + kwargs = { + 'global_size': [M//256, N//128, 1], + 'local_size': [32, 4, 2], # 8 warpgroups, WG_M=4 and WG_N=2 + 'wait': True, + 'vals': (N, K), + } + else: + raise RuntimeError(f"invalid gemm variation: {GEMM_VARIATION=} {M=} {N=} {K=} {DTYPE_IN=} {DTYPE_OUT=} {DTYPE_ACC=}") + + tms = [] + na, nb, nc = randoms() + cudaalloc.copyin(a, bytearray(na)) + cudaalloc.copyin(b, bytearray(nb)) + for i in range(CNT): + tms.append(prog(*args, **kwargs)) + cudaalloc.copyout(flat_mv(nc.data), c) + comp = na.astype(np.float32) @ nb.astype(np.float32) + result = nc.reshape(M, N).astype(np.float32) + + print(f"{N*N:10d} {min(tms)*1e6:9.2f} us, would be {FLOPS*1e-9/min(tms):9.2f} GFLOPS matmul, {BW*1e-9/min(tms):.2f} GB/s") + try: + np.testing.assert_allclose(result, comp, atol=ATOL, rtol=RTOL) + except AssertionError as e: + if getenv("DEBUG_VALUES") > 0: + indices = np.where(~np.isclose(result, comp, rtol=RTOL, atol=ATOL)) + non_matching_elements_result = result[indices] + non_matching_elements_comp = comp[indices] + print("valid :", np.where(np.isclose(result, comp, rtol=RTOL, atol=ATOL))) + print("invalid :", indices) + print("result :", non_matching_elements_result) + print("ground truth:", non_matching_elements_comp) + print("result sum :", np.sum(result)) + print("ground sum :", np.sum(comp)) + raise e + + if getenv("DEBUG_VALUES") > 0: + print(comp) + print("ground sum :", np.sum(comp)) + print(result) + print("result sum :", np.sum(result)) + + elif getenv("AMD") == 1: + # note: https://hipfft.readthedocs.io/en/rocm-6.1.2/how-to/fine-tuning-llms/optimizing-triton-kernel.html + + # also this is different than the rocblas/tensile approach to GEMM + # see: https://github.com/ROCm/Tensile/blob/develop/Tensile/KernelWriterAssembly.py + raise RuntimeError("invalid max_matmul device") + + else: + raise RuntimeError("invalid max_matmul device") +