diff --git a/icicle/appUtils/ntt/kernel_ntt.cu b/icicle/appUtils/ntt/kernel_ntt.cu index a45c90b5..55ea561e 100644 --- a/icicle/appUtils/ntt/kernel_ntt.cu +++ b/icicle/appUtils/ntt/kernel_ntt.cu @@ -56,7 +56,15 @@ namespace ntt { // Note: the following reorder kernels are fused with normalization for INTT template static __global__ void reorder_digits_inplace_and_normalize_kernel( - E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N) + E* arr, + uint32_t log_size, + bool columns_batch, + uint32_t batch_size, + bool dit, + bool fast_tw, + eRevType rev_type, + bool is_normalize, + S inverse_N) { // launch N threads (per batch element) // each thread starts from one index and calculates the corresponding group @@ -65,19 +73,20 @@ namespace ntt { const uint32_t size = 1 << log_size; const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; - const uint32_t idx = tid % size; - const uint32_t batch_idx = tid / size; + const uint32_t idx = columns_batch ? tid / batch_size : tid % size; + const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size; + if (tid >= size * batch_size) return; uint32_t next_element = idx; uint32_t group[MAX_GROUP_SIZE]; - group[0] = next_element + size * batch_idx; + group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx; uint32_t i = 1; for (; i < MAX_GROUP_SIZE;) { next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type); if (next_element < idx) return; // not handling this group if (next_element == idx) break; // calculated whole group - group[i++] = next_element + size * batch_idx; + group[i++] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx; } --i; @@ -94,6 +103,9 @@ namespace ntt { E* arr, E* arr_reordered, uint32_t log_size, + bool columns_batch, + uint32_t batch_size, + uint32_t columns_batch_size, bool dit, bool fast_tw, eRevType rev_type, @@ -101,29 +113,33 @@ namespace ntt { S inverse_N) { uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid >= (1 << log_size) * batch_size) return; uint32_t rd = tid; - uint32_t wr = - ((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type); - arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd]; + uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) + + generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type); + arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? 
arr[rd] * inverse_N : arr[rd]; } template - static __global__ void batch_elementwise_mul_with_reorder( + static __global__ void batch_elementwise_mul_with_reorder_kernel( E* in_vec, - int n_elements, - int batch_size, + uint32_t size, + bool columns_batch, + uint32_t batch_size, + uint32_t columns_batch_size, S* scalar_vec, int step, int n_scalars, - int logn, + uint32_t log_size, eRevType rev_type, bool dit, E* out_vec) { int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid >= n_elements * batch_size) return; - int64_t scalar_id = tid % n_elements; - if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type); + if (tid >= size * batch_size) return; + int64_t scalar_id = (tid / columns_batch_size) % size; + if (rev_type != eRevType::None) + scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type); out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid]; } @@ -136,6 +152,7 @@ namespace ntt { S* basic_twiddles, uint32_t log_size, uint32_t tw_log_size, + uint32_t columns_batch_size, uint32_t nof_ntt_blocks, uint32_t data_stride, uint32_t log_data_stride, @@ -153,19 +170,27 @@ namespace ntt { s_meta.th_stride = 8; s_meta.ntt_block_size = 64; - s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3)); + s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 7) / 8) + : (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3)); s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7); - if (s_meta.ntt_block_id >= nof_ntt_blocks) return; + s_meta.batch_id = + columns_batch_size ? (threadIdx.x & 0x7) + ((blockIdx.x % ((columns_batch_size + 7) / 8)) << 3) : 0; + if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size)) + return; if (fast_tw) engine.loadBasicTwiddles(basic_twiddles); else engine.loadBasicTwiddlesGeneric(basic_twiddles, inv); - engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta); + if (columns_batch_size) + engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta); + if (twiddle_stride && dit) { if (fast_tw) - engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta); + engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta); else engine.loadExternalTwiddlesGeneric64( external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); @@ -189,13 +214,16 @@ namespace ntt { if (twiddle_stride && !dit) { if (fast_tw) - engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta); + engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta); else engine.loadExternalTwiddlesGeneric64( external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); engine.twiddlesExternal(); } - engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta); + if (columns_batch_size) + engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta); } template @@ -207,6 +235,7 @@ namespace ntt { S* basic_twiddles, uint32_t log_size, uint32_t tw_log_size, + uint32_t columns_batch_size, uint32_t 
nof_ntt_blocks, uint32_t data_stride, uint32_t log_data_stride, @@ -225,16 +254,25 @@ namespace ntt { s_meta.th_stride = 4; s_meta.ntt_block_size = 32; - s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2)); + s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16) + : (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2)); s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3); - if (s_meta.ntt_block_id >= nof_ntt_blocks) return; + s_meta.batch_id = + columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0; + if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size)) + return; if (fast_tw) engine.loadBasicTwiddles(basic_twiddles); else engine.loadBasicTwiddlesGeneric(basic_twiddles, inv); - engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta); + + if (columns_batch_size) + engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta); + if (fast_tw) engine.loadInternalTwiddles32(internal_twiddles, strided); else @@ -247,13 +285,16 @@ namespace ntt { engine.ntt4_2(); if (twiddle_stride) { if (fast_tw) - engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta); + engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta); else engine.loadExternalTwiddlesGeneric32( external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); engine.twiddlesExternal(); } - engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta); + if (columns_batch_size) + engine.storeGlobalData32ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.storeGlobalData32(out, data_stride, log_data_stride, strided, s_meta); } template @@ -265,6 +306,7 @@ namespace ntt { S* basic_twiddles, uint32_t log_size, uint32_t tw_log_size, + uint32_t columns_batch_size, uint32_t nof_ntt_blocks, uint32_t data_stride, uint32_t log_data_stride, @@ -283,19 +325,27 @@ namespace ntt { s_meta.th_stride = 4; s_meta.ntt_block_size = 32; - s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2)); + s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16) + : (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2)); s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3); - if (s_meta.ntt_block_id >= nof_ntt_blocks) return; + s_meta.batch_id = + columns_batch_size ? 
(threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0; + if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size)) + return; if (fast_tw) engine.loadBasicTwiddles(basic_twiddles); else engine.loadBasicTwiddlesGeneric(basic_twiddles, inv); - engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta); + + if (columns_batch_size) + engine.loadGlobalData32ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.loadGlobalData32(in, data_stride, log_data_stride, strided, s_meta); if (twiddle_stride) { if (fast_tw) - engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta); + engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta); else engine.loadExternalTwiddlesGeneric32( external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); @@ -311,7 +361,10 @@ namespace ntt { engine.SharedData32Rows8(shmem, false, false, strided); // load engine.twiddlesInternal(); engine.ntt8win(); - engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta); + if (columns_batch_size) + engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta); } template @@ -323,6 +376,7 @@ namespace ntt { S* basic_twiddles, uint32_t log_size, uint32_t tw_log_size, + uint32_t columns_batch_size, uint32_t nof_ntt_blocks, uint32_t data_stride, uint32_t log_data_stride, @@ -341,16 +395,26 @@ namespace ntt { s_meta.th_stride = 2; s_meta.ntt_block_size = 16; - s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1)); + s_meta.ntt_block_id = columns_batch_size + ? blockIdx.x / ((columns_batch_size + 31) / 32) + : (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1)); s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1); - if (s_meta.ntt_block_id >= nof_ntt_blocks) return; + s_meta.batch_id = + columns_batch_size ? 
(threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0; + if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size)) + return; if (fast_tw) engine.loadBasicTwiddles(basic_twiddles); else engine.loadBasicTwiddlesGeneric(basic_twiddles, inv); - engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta); + + if (columns_batch_size) + engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta); + if (fast_tw) engine.loadInternalTwiddles16(internal_twiddles, strided); else @@ -363,13 +427,16 @@ namespace ntt { engine.ntt2_4(); if (twiddle_stride) { if (fast_tw) - engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta); + engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta); else engine.loadExternalTwiddlesGeneric16( external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); engine.twiddlesExternal(); } - engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta); + if (columns_batch_size) + engine.storeGlobalData16ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.storeGlobalData16(out, data_stride, log_data_stride, strided, s_meta); } template @@ -381,6 +448,7 @@ namespace ntt { S* basic_twiddles, uint32_t log_size, uint32_t tw_log_size, + uint32_t columns_batch_size, uint32_t nof_ntt_blocks, uint32_t data_stride, uint32_t log_data_stride, @@ -399,19 +467,29 @@ namespace ntt { s_meta.th_stride = 2; s_meta.ntt_block_size = 16; - s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1)); + s_meta.ntt_block_id = columns_batch_size + ? blockIdx.x / ((columns_batch_size + 31) / 32) + : (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1)); s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1); - if (s_meta.ntt_block_id >= nof_ntt_blocks) return; + s_meta.batch_id = + columns_batch_size ? 
(threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0; + if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size)) + return; if (fast_tw) engine.loadBasicTwiddles(basic_twiddles); else engine.loadBasicTwiddlesGeneric(basic_twiddles, inv); - engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta); + + if (columns_batch_size) + engine.loadGlobalData16ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.loadGlobalData16(in, data_stride, log_data_stride, strided, s_meta); + if (twiddle_stride) { if (fast_tw) - engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta); + engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta); else engine.loadExternalTwiddlesGeneric16( external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); @@ -427,13 +505,17 @@ namespace ntt { engine.SharedData16Rows8(shmem, false, false, strided); // load engine.twiddlesInternal(); engine.ntt8win(); - engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta); + if (columns_batch_size) + engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size); + else + engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta); } template - __global__ void normalize_kernel(E* data, S norm_factor) + __global__ void normalize_kernel(E* data, S norm_factor, uint32_t size) { uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= size) return; data[tid] = data[tid] * norm_factor; } @@ -666,6 +748,7 @@ namespace ntt { uint32_t log_size, uint32_t tw_log_size, uint32_t batch_size, + bool columns_batch, bool inv, bool normalize, bool dit, @@ -679,72 +762,83 @@ namespace ntt { } if (log_size == 4) { - const int NOF_THREADS = min(64, 2 * batch_size); - const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS; - + const int NOF_THREADS = columns_batch ? 64 : min(64, 2 * batch_size); + const int NOF_BLOCKS = + columns_batch ? ((batch_size + 31) / 32) : (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS; if (dit) { ntt16dit<<>>( - in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0, - false, 0, inv, dit, fast_tw); + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); } else { // dif ntt16<<>>( - in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0, - false, 0, inv, dit, fast_tw); + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); } - if (normalize) normalize_kernel<<>>(out, S::inv_log_size(4)); + if (normalize) + normalize_kernel<<>>(out, S::inv_log_size(4), (1 << log_size) * batch_size); return CHK_LAST(); } if (log_size == 5) { - const int NOF_THREADS = min(64, 4 * batch_size); - const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS; + const int NOF_THREADS = columns_batch ? 64 : min(64, 4 * batch_size); + const int NOF_BLOCKS = + columns_batch ? 
((batch_size + 15) / 16) : (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS; if (dit) { ntt32dit<<>>( - in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0, - false, 0, inv, dit, fast_tw); + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); } else { // dif ntt32<<>>( - in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0, - false, 0, inv, dit, fast_tw); + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); } - if (normalize) normalize_kernel<<>>(out, S::inv_log_size(5)); + if (normalize) + normalize_kernel<<>>(out, S::inv_log_size(5), (1 << log_size) * batch_size); return CHK_LAST(); } if (log_size == 6) { - const int NOF_THREADS = min(64, 8 * batch_size); - const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS; + const int NOF_THREADS = columns_batch ? 64 : min(64, 8 * batch_size); + const int NOF_BLOCKS = + columns_batch ? ((batch_size + 7) / 8) : ((8 * batch_size + NOF_THREADS - 1) / NOF_THREADS); ntt64<<>>( - in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0, - false, 0, inv, dit, fast_tw); - if (normalize) normalize_kernel<<>>(out, S::inv_log_size(6)); + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw); + if (normalize) + normalize_kernel<<>>(out, S::inv_log_size(6), (1 << log_size) * batch_size); return CHK_LAST(); } if (log_size == 8) { const int NOF_THREADS = 64; - const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS; + const int NOF_BLOCKS = + columns_batch ? ((batch_size + 31) / 32 * 16) : ((32 * batch_size + NOF_THREADS - 1) / NOF_THREADS); if (dit) { ntt16dit<<>>( in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0, + columns_batch, 0, inv, dit, fast_tw); ntt16dit<<>>( out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1, + inv, dit, fast_tw); } else { // dif ntt16<<>>( in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1, + inv, dit, fast_tw); ntt16<<>>( out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 
1 : batch_size), 1, 0, 0, + columns_batch, 0, inv, dit, fast_tw); } - if (normalize) normalize_kernel<<>>(out, S::inv_log_size(8)); + if (normalize) + normalize_kernel<<>>(out, S::inv_log_size(8), (1 << log_size) * batch_size); return CHK_LAST(); } // general case: - uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size; + uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size); if (dit) { for (int i = 0; i < 5; i++) { uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i]; @@ -754,18 +848,18 @@ namespace ntt { if (stage_size == 6) ntt64<<>>( i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit, - fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log, + stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw); else if (stage_size == 5) ntt32dit<<>>( i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit, - fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log, + stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw); else if (stage_size == 4) ntt16dit<<>>( i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit, - fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log, + stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw); } } else { // dif bool first_run = false, prev_stage = false; @@ -778,23 +872,24 @@ namespace ntt { if (stage_size == 6) ntt64<<>>( first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit, - fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log, + stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw); else if (stage_size == 5) ntt32<<>>( first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit, - fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log, + stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw); else if (stage_size == 4) ntt16<<>>( first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, - (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit, - fast_tw); + columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log, + stride_log, i ? 
(1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw); prev_stage = stage_size; } } if (normalize) - normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size)); + normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>( + out, S::inv_log_size(log_size), (1 << log_size) * batch_size); return CHK_LAST(); } @@ -809,6 +904,7 @@ namespace ntt { int ntt_size, int max_logn, int batch_size, + bool columns_batch, bool is_inverse, bool fast_tw, Ordering ordering, @@ -858,9 +954,10 @@ namespace ntt { } if (is_on_coset && !is_inverse) { - batch_elementwise_mul_with_reorder<<>>( - d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles, - arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output); + batch_elementwise_mul_with_reorder_kernel<<>>( + d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1, + arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, + reverse_coset, dit, d_output); d_input = d_output; } @@ -869,10 +966,11 @@ namespace ntt { const bool is_reverse_in_place = (d_input == d_output); if (is_reverse_in_place) { reorder_digits_inplace_and_normalize_kernel<<>>( - d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn)); + d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn)); } else { reorder_digits_and_normalize_kernel<<>>( - d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn)); + d_input, d_output, logn, columns_batch, batch_size, columns_batch ? batch_size : 1, dit, fast_tw, + reverse_input, is_normalize, S::inv_log_size(logn)); } is_normalize = false; d_input = d_output; @@ -880,18 +978,19 @@ namespace ntt { // inplace ntt CHK_IF_RETURN(large_ntt( - d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse, - (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream)); + d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, + columns_batch, is_inverse, (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream)); if (reverse_output != eRevType::None) { reorder_digits_inplace_and_normalize_kernel<<>>( - d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn)); + d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn)); } if (is_on_coset && is_inverse) { - batch_elementwise_mul_with_reorder<<>>( - d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, - arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output); + batch_elementwise_mul_with_reorder_kernel<<>>( + d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1, + arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 
1 : -coset_gen_index, + n_twiddles, logn, reverse_coset, dit, d_output); } return CHK_LAST(); @@ -923,6 +1022,7 @@ namespace ntt { int ntt_size, int max_logn, int batch_size, + bool columns_batch, bool is_inverse, bool fast_tw, Ordering ordering, diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu index 45c35692..871185e8 100644 --- a/icicle/appUtils/ntt/ntt.cu +++ b/icicle/appUtils/ntt/ntt.cu @@ -516,12 +516,15 @@ namespace ntt { static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig& config) { const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7); + if (!is_mixed_radix_alg_supported && config.columns_batch) + throw IcicleError(IcicleError_t::InvalidArgument, "columns batch is not supported for given NTT size"); const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2; const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg; if (is_force_radix2) return true; const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix; if (is_user_selected_mixed_radix_alg) return false; + if (config.columns_batch) return false; // radix2 does not currently support columns batch mode. // Heuristic to automatically select an algorithm // Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and @@ -663,7 +666,7 @@ namespace ntt { CHK_IF_RETURN(ntt::mixed_radix_ntt( d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size, - is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream)); + config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream)); } if (!are_outputs_on_device) @@ -685,6 +688,7 @@ namespace ntt { ctx, // ctx S::one(), // coset_gen 1, // batch_size + false, // columns_batch Ordering::kNN, // ordering false, // are_inputs_on_device false, // are_outputs_on_device diff --git a/icicle/appUtils/ntt/ntt.cuh b/icicle/appUtils/ntt/ntt.cuh index c91f54ff..9faf27bd 100644 --- a/icicle/appUtils/ntt/ntt.cuh +++ b/icicle/appUtils/ntt/ntt.cuh @@ -95,6 +95,8 @@ namespace ntt { S coset_gen; /**< Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()` * (corresponding to no coset being used). */ int batch_size; /**< The number of NTTs to compute. Default value: 1. */ + bool columns_batch; /**< True if the batches are the columns of an input matrix + (they are strided in memory with a stride of ntt size) Default value: false. */ Ordering ordering; /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: * `Ordering::kNN`. */ bool are_inputs_on_device; /**< True if inputs are on device and false if they're on host. Default value: false. 
*/ diff --git a/icicle/appUtils/ntt/ntt_impl.cuh b/icicle/appUtils/ntt/ntt_impl.cuh index 15587846..b4cd162d 100644 --- a/icicle/appUtils/ntt/ntt_impl.cuh +++ b/icicle/appUtils/ntt/ntt_impl.cuh @@ -35,6 +35,7 @@ namespace ntt { int ntt_size, int max_logn, int batch_size, + bool columns_batch, bool is_inverse, bool fast_tw, Ordering ordering, diff --git a/icicle/appUtils/ntt/tests/verification.cu b/icicle/appUtils/ntt/tests/verification.cu index 341d951e..751ffe09 100644 --- a/icicle/appUtils/ntt/tests/verification.cu +++ b/icicle/appUtils/ntt/tests/verification.cu @@ -29,6 +29,13 @@ void incremental_values(test_scalar* res, uint32_t count) } } +__global__ void transpose_batch(test_scalar* in, test_scalar* out, int row_size, int column_size) +{ + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid >= row_size * column_size) return; + out[(tid % row_size) * column_size + (tid / row_size)] = in[tid]; +} + int main(int argc, char** argv) { cudaEvent_t icicle_start, icicle_stop, new_start, new_stop; @@ -37,11 +44,12 @@ int main(int argc, char** argv) int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; int NTT_SIZE = 1 << NTT_LOG_SIZE; bool INPLACE = (argc > 2) ? atoi(argv[2]) : false; - int INV = (argc > 3) ? atoi(argv[3]) : true; - int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1; - int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0; - const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN; - bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true; + int INV = (argc > 3) ? atoi(argv[3]) : false; + int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 150; + bool COLUMNS_BATCH = (argc > 5) ? atoi(argv[5]) : false; + int COSET_IDX = (argc > 6) ? atoi(argv[6]) : 2; + const ntt::Ordering ordering = (argc > 7) ? ntt::Ordering(atoi(argv[7])) : ntt::Ordering::kNN; + bool FAST_TW = (argc > 8) ? atoi(argv[8]) : true; // Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs const char* ordering_str = ordering == ntt::Ordering::kNN ? 
"NN" @@ -52,8 +60,8 @@ int main(int argc, char** argv) : "MN"; printf( - "running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE, - INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW); + "running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, columns_batch=%d coset-idx=%d, ordering=%s, fast_tw=%d\n", + NTT_LOG_SIZE, INPLACE, INV, BATCH_SIZE, COLUMNS_BATCH, COSET_IDX, ordering_str, FAST_TW); CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup) @@ -63,6 +71,7 @@ int main(int argc, char** argv) ntt_config.are_inputs_on_device = true; ntt_config.are_outputs_on_device = true; ntt_config.batch_size = BATCH_SIZE; + ntt_config.columns_batch = COLUMNS_BATCH; CHK_IF_RETURN(cudaEventCreate(&icicle_start)); CHK_IF_RETURN(cudaEventCreate(&icicle_stop)); @@ -83,7 +92,9 @@ int main(int argc, char** argv) // gpu allocation test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew; + test_data* GpuScalarsTransposed; CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE * BATCH_SIZE)); + CHK_IF_RETURN(cudaMalloc(&GpuScalarsTransposed, sizeof(test_data) * NTT_SIZE * BATCH_SIZE)); CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE * BATCH_SIZE)); CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE * BATCH_SIZE)); @@ -93,10 +104,16 @@ int main(int argc, char** argv) CHK_IF_RETURN( cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyHostToDevice)); + if (COLUMNS_BATCH) { + transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>( + GpuScalars, GpuScalarsTransposed, NTT_SIZE, BATCH_SIZE); + } + // inplace if (INPLACE) { - CHK_IF_RETURN( - cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice)); + CHK_IF_RETURN(cudaMemcpy( + GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), + cudaMemcpyDeviceToDevice)); } for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) { @@ -109,13 +126,14 @@ int main(int argc, char** argv) ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix; for (size_t i = 0; i < iterations; i++) { CHK_IF_RETURN(ntt::NTT( - INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, - GpuOutputNew)); + INPLACE ? GpuOutputNew + : COLUMNS_BATCH ? GpuScalarsTransposed + : GpuScalars, + NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputNew)); } CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream)); CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream)); CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop)); - if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); } // OLD CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream)); @@ -127,7 +145,6 @@ int main(int argc, char** argv) CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream)); CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream)); CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop)); - if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); } if (is_print) { printf("Old Runtime=%0.3f MS\n", icicle_time / iterations); @@ -140,11 +157,19 @@ int main(int argc, char** argv) CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup int count = INPLACE ? 
1 : 10; if (INPLACE) { - CHK_IF_RETURN( - cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice)); + CHK_IF_RETURN(cudaMemcpy( + GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), + cudaMemcpyDeviceToDevice)); } CHK_IF_RETURN(benchmark(true /*=print*/, count)); + if (COLUMNS_BATCH) { + transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>( + GpuOutputNew, GpuScalarsTransposed, BATCH_SIZE, NTT_SIZE); + CHK_IF_RETURN(cudaMemcpy( + GpuOutputNew, GpuScalarsTransposed, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice)); + } + // verify CHK_IF_RETURN( cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost)); @@ -153,10 +178,11 @@ int main(int argc, char** argv) bool success = true; for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) { + // if (i%64==0) printf("\n"); if (CpuOutputNew[i] != CpuOutputOld[i]) { success = false; // std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl; - break; + // break; } else { // std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl; // break; diff --git a/icicle/appUtils/ntt/thread_ntt.cu b/icicle/appUtils/ntt/thread_ntt.cu index 9c071bc7..ab61320f 100644 --- a/icicle/appUtils/ntt/thread_ntt.cu +++ b/icicle/appUtils/ntt/thread_ntt.cu @@ -9,6 +9,7 @@ struct stage_metadata { uint32_t th_stride; uint32_t ntt_block_size; + uint32_t batch_id; uint32_t ntt_block_id; uint32_t ntt_inp_id; }; @@ -118,7 +119,7 @@ public: } __device__ __forceinline__ void - loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta) + loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta) { data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1)); @@ -129,7 +130,7 @@ public: } __device__ __forceinline__ void - loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta) + loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta) { data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1)); @@ -143,7 +144,7 @@ public: } __device__ __forceinline__ void - loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta) + loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta) { data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1)); @@ -195,8 +196,8 @@ public: } } - __device__ __forceinline__ void loadGlobalData( - E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + __device__ __forceinline__ void + loadGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + @@ -211,8 +212,22 @@ public: } } - __device__ __forceinline__ void storeGlobalData( - E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + __device__ __forceinline__ void loadGlobalDataColumnBatch( + E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + { + data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * 
s_meta.ntt_inp_id + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * + batch_size + + s_meta.batch_id; + +#pragma unroll + for (uint32_t i = 0; i < 8; i++) { + X[i] = data[s_meta.th_stride * i * data_stride * batch_size]; + } + } + + __device__ __forceinline__ void + storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + @@ -227,8 +242,22 @@ public: } } - __device__ __forceinline__ void loadGlobalData32( - E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + __device__ __forceinline__ void storeGlobalDataColumnBatch( + E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + { + data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * + batch_size + + s_meta.batch_id; + +#pragma unroll + for (uint32_t i = 0; i < 8; i++) { + data[s_meta.th_stride * i * data_stride * batch_size] = X[i]; + } + } + + __device__ __forceinline__ void + loadGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + @@ -246,8 +275,25 @@ public: } } - __device__ __forceinline__ void storeGlobalData32( - E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + __device__ __forceinline__ void loadGlobalData32ColumnBatch( + E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + { + data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * + batch_size + + s_meta.batch_id; + +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { +#pragma unroll + for (uint32_t i = 0; i < 4; i++) { + X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size]; + } + } + } + + __device__ __forceinline__ void + storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + @@ -265,8 +311,25 @@ public: } } - __device__ __forceinline__ void loadGlobalData16( - E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + __device__ __forceinline__ void storeGlobalData32ColumnBatch( + E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + { + data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * + batch_size + + s_meta.batch_id; + +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { +#pragma unroll + for (uint32_t i = 0; i < 4; i++) { + data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i]; + } + } + } + + __device__ __forceinline__ void + loadGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + @@ -284,8 +347,25 @@ public: } } - 
__device__ __forceinline__ void storeGlobalData16( - E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + __device__ __forceinline__ void loadGlobalData16ColumnBatch( + E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + { + data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * + batch_size + + s_meta.batch_id; + +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size]; + } + } + } + + __device__ __forceinline__ void + storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) { if (strided) { data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + @@ -303,6 +383,23 @@ public: } } + __device__ __forceinline__ void storeGlobalData16ColumnBatch( + E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size) + { + data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) * + batch_size + + s_meta.batch_id; + +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i]; + } + } + } + __device__ __forceinline__ void ntt4_2() { #pragma unroll diff --git a/wrappers/golang/core/ntt.go b/wrappers/golang/core/ntt.go index 54a476e9..1fd0f48f 100644 --- a/wrappers/golang/core/ntt.go +++ b/wrappers/golang/core/ntt.go @@ -31,6 +31,8 @@ type NTTConfig[T any] struct { CosetGen T /// The number of NTTs to compute. Default value: 1. BatchSize int32 + /// If true the function will compute the NTTs over the columns of the input matrix and not over the rows. + ColumnsBatch bool /// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`. Ordering Ordering areInputsOnDevice bool @@ -46,6 +48,7 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] { ctx, // Ctx cosetGen, // CosetGen 1, // BatchSize + false, // ColumnsBatch KNN, // Ordering false, // areInputsOnDevice false, // areOutputsOnDevice diff --git a/wrappers/golang/core/ntt_test.go b/wrappers/golang/core/ntt_test.go index ccea6361..93225191 100644 --- a/wrappers/golang/core/ntt_test.go +++ b/wrappers/golang/core/ntt_test.go @@ -19,6 +19,7 @@ func TestNTTDefaultConfig(t *testing.T) { ctx, // Ctx cosetGen, // CosetGen 1, // BatchSize + false, // ColumnsBatch KNN, // Ordering false, // areInputsOnDevice false, // areOutputsOnDevice diff --git a/wrappers/rust/icicle-core/src/ntt/mod.rs b/wrappers/rust/icicle-core/src/ntt/mod.rs index 758b11e5..4c27f435 100644 --- a/wrappers/rust/icicle-core/src/ntt/mod.rs +++ b/wrappers/rust/icicle-core/src/ntt/mod.rs @@ -77,6 +77,8 @@ pub struct NTTConfig<'a, S> { pub coset_gen: S, /// The number of NTTs to compute. Default value: 1. pub batch_size: i32, + /// If true the function will compute the NTTs over the columns of the input matrix and not over the rows. + pub columns_batch: bool, /// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`. 
pub ordering: Ordering, are_inputs_on_device: bool, @@ -101,6 +103,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> { ctx: DeviceContext::default_for_device(device_id), coset_gen: S::one(), batch_size: 1, + columns_batch: false, ordering: Ordering::kNN, are_inputs_on_device: false, are_outputs_on_device: false, diff --git a/wrappers/rust/icicle-core/src/ntt/tests.rs b/wrappers/rust/icicle-core/src/ntt/tests.rs index 880c52f5..514efe63 100644 --- a/wrappers/rust/icicle-core/src/ntt/tests.rs +++ b/wrappers/rust/icicle-core/src/ntt/tests.rs @@ -44,6 +44,14 @@ pub fn reverse_bit_order(n: u32, order: u32) -> u32 { u32::from_str_radix(&reversed, 2).unwrap() } +pub fn transpose_flattened_matrix(m: &[T], nrows: usize) -> Vec { + let ncols = m.len() / nrows; + assert!(nrows * ncols == m.len()); + (0..m.len()) + .map(|i| m[(i % nrows) * ncols + i / nrows]) + .collect() +} + pub fn list_to_reverse_bit_order(l: &[T]) -> Vec { l.iter() .enumerate() @@ -253,11 +261,10 @@ where ] { config.coset_gen = coset_gen; config.ordering = ordering; + let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]); for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] { config.batch_size = batch_size as i32; config.ntt_algorithm = alg; - let mut batch_ntt_result = - HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]); ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap(); config.batch_size = 1; let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]); @@ -275,6 +282,20 @@ where ); } } + + // for now, columns batching only works with MixedRadix NTT + config.batch_size = batch_size as i32; + config.columns_batch = true; + let transposed_input = + HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size)); + let mut col_batch_ntt_result = + HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]); + ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap(); + assert_eq!( + batch_ntt_result[..], + transpose_flattened_matrix(&col_batch_ntt_result[..], test_size) + ); + config.columns_batch = false; } } }
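
Note on the columns-batch memory layout: in the default (rows) batch, NTT b occupies a contiguous block of size elements starting at offset b * size. With columns_batch enabled, the batches are the columns of a size x batch_size matrix, so element i of column b lives at i * batch_size + b and consecutive elements of one NTT are batch_size apart; this is the indexing reorder_digits_inplace_and_normalize_kernel uses (next_element * batch_size + batch_idx) and what the transpose_batch kernel added to verification.cu produces. A minimal CUDA sketch of the two addressing schemes follows; the helper names and the float element type are illustrative and not part of the diff.

#include <cstdint>

// Illustrative index helpers (names not from the diff): where element `idx` of
// batch `batch_idx` lives in each layout. columns_batch_index() matches the
// addressing used by reorder_digits_inplace_and_normalize_kernel above
// (group[0] = next_element * batch_size + batch_idx).
__host__ __device__ inline uint32_t rows_batch_index(uint32_t idx, uint32_t batch_idx, uint32_t size)
{
  return batch_idx * size + idx; // each NTT is one contiguous block of `size` elements
}

__host__ __device__ inline uint32_t columns_batch_index(uint32_t idx, uint32_t batch_idx, uint32_t batch_size)
{
  return idx * batch_size + batch_idx; // consecutive elements of one NTT are `batch_size` apart
}

// Same mapping as the transpose_batch kernel in verification.cu: converts the
// rows layout (row_size elements per batch) into the columns layout. `float`
// stands in for the field type used by the test.
__global__ void transpose_batch_sketch(const float* in, float* out, int row_size, int column_size)
{
  int tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid >= row_size * column_size) return;
  out[(tid % row_size) * column_size + (tid / row_size)] = in[tid];
}

The verification test relies on this being a plain transpose: it transposes the row-major inputs into the columns layout before the columns-batch NTT (transpose_batch with row_size = NTT_SIZE, column_size = BATCH_SIZE) and transposes the outputs back with the sizes swapped before comparing against the row-layout reference, and the Rust test does the same on the host via transpose_flattened_matrix.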
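
Note on the grid decomposition in columns-batch mode: the shared-memory kernels keep 64 threads per block, but when columns_batch_size is non-zero each block handles a single NTT block for a slice of adjacent columns (8 columns for the radix-64 kernel, 16 for radix-32, 32 for radix-16) instead of several NTT blocks of one batch; blockIdx.x divided by the number of column slices selects the NTT block, and the remainder combined with the low thread bits selects the column. The host-side sketch below is illustrative and not part of the diff: it replays the radix-64 mapping for the log_size == 6 launch, where large_ntt passes nof_ntt_blocks = 1 and launches ceil(batch_size / 8) blocks, and checks that every (NTT block, input lane, column) triple is covered exactly once. Threads whose batch_id falls past columns_batch_size in the last, partial slice simply hit the extra bound added to each kernel's early return.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <set>
#include <tuple>

// Host-side sketch of the work mapping the radix-64 kernel uses in
// columns-batch mode, with the grid large_ntt() picks for log_size == 6:
// NOF_THREADS = 64, NOF_BLOCKS = ceil(batch_size / 8), nof_ntt_blocks = 1.
// Formulas are copied from ntt64():
//   blocks_per_ntt = (columns_batch_size + 7) / 8
//   ntt_block_id   = blockIdx.x / blocks_per_ntt
//   batch_id       = (threadIdx.x & 0x7) + ((blockIdx.x % blocks_per_ntt) << 3)
//   ntt_inp_id     = threadIdx.x >> 3        (strided == true in this mode)
int main()
{
  const uint32_t columns_batch_size = 150; // the test's default BATCH_SIZE
  const uint32_t nof_ntt_blocks = 1;       // one size-64 NTT per column at log_size == 6
  const uint32_t blocks_per_ntt = (columns_batch_size + 7) / 8;
  const uint32_t grid_size = nof_ntt_blocks * blocks_per_ntt;

  std::set<std::tuple<uint32_t, uint32_t, uint32_t>> covered; // (ntt_block_id, ntt_inp_id, batch_id)
  for (uint32_t block_idx = 0; block_idx < grid_size; block_idx++) {
    for (uint32_t thread_idx = 0; thread_idx < 64; thread_idx++) {
      const uint32_t ntt_block_id = block_idx / blocks_per_ntt;
      const uint32_t batch_id = (thread_idx & 0x7) + ((block_idx % blocks_per_ntt) << 3);
      const uint32_t ntt_inp_id = thread_idx >> 3;
      // the kernel early-returns on out-of-range blocks/columns (last, partial slice)
      if (ntt_block_id >= nof_ntt_blocks || batch_id >= columns_batch_size) continue;
      const bool is_new = covered.insert({ntt_block_id, ntt_inp_id, batch_id}).second;
      assert(is_new); // no work item is produced twice
    }
  }
  assert(covered.size() == (size_t)nof_ntt_blocks * 8 * columns_batch_size); // and none is missed
  printf("covered %zu (ntt block, lane, column) work items\n", covered.size());
  return 0;
}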
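
Note on the new engine load/store paths in thread_ntt.cu: each *ColumnBatch method computes the same base offset as the strided row-layout path of its counterpart (loadGlobalData, loadGlobalData32, and so on), scales it by the column-batch size, adds s_meta.batch_id, and scales the per-element stride by the batch size as well, so for every element the columns-batch offset equals the row-layout offset times batch_size plus the column index. The strided flag disappears from these signatures because column batching always takes the strided path (note the i || columns_batch argument in large_ntt). Below is a standalone sketch for the 64-point case; the helper names are illustrative, and it assumes the strided row-layout base is the expression that appears inside the parentheses of loadGlobalDataColumnBatch.

#include <cstdint>

// Illustrative helpers (not part of the diff). row_layout_offset64() is the
// offset of the i-th element handled by one thread of the 64-point engine in
// the strided row layout; columns_batch_offset64() is that value scaled by the
// batch size plus the column index, which is what
// loadGlobalDataColumnBatch()/storeGlobalDataColumnBatch() compute above.
__host__ __device__ inline uint32_t row_layout_offset64(
  uint32_t i, uint32_t ntt_block_id, uint32_t ntt_inp_id, uint32_t ntt_block_size, uint32_t th_stride,
  uint32_t data_stride, uint32_t log_data_stride)
{
  const uint32_t base = (ntt_block_id & (data_stride - 1)) + data_stride * ntt_inp_id +
                        (ntt_block_id >> log_data_stride) * data_stride * ntt_block_size;
  return base + th_stride * i * data_stride; // th_stride == 8 for the 64-point engine
}

__host__ __device__ inline uint32_t columns_batch_offset64(
  uint32_t i, uint32_t ntt_block_id, uint32_t ntt_inp_id, uint32_t ntt_block_size, uint32_t th_stride,
  uint32_t data_stride, uint32_t log_data_stride, uint32_t batch_size, uint32_t batch_id)
{
  // same element, with every logical index spread out by `batch_size` and shifted to column `batch_id`
  return row_layout_offset64(i, ntt_block_id, ntt_inp_id, ntt_block_size, th_stride, data_stride, log_data_stride) *
           batch_size +
         batch_id;
}

The 32- and 16-point variants follow the same pattern; only the base term (ntt_inp_id scaled by 2 or 4) and the per-element index differ, exactly as in their row-layout counterparts.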