diff --git a/examples/c++/polynomial_multiplication/example.cu b/examples/c++/polynomial_multiplication/example.cu index 7100fee9..ada72112 100644 --- a/examples/c++/polynomial_multiplication/example.cu +++ b/examples/c++/polynomial_multiplication/example.cu @@ -46,10 +46,10 @@ int main(int argc, char** argv) // init domain auto ntt_config = ntt::DefaultNTTConfig(); - ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward - ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false; + const bool is_radix2_alg = (argc > 1) ? atoi(argv[1]) : false; + ntt_config.ntt_algorithm = is_radix2_alg ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix; - const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix"; + const char* ntt_alg_str = is_radix2_alg ? "Radix-2" : "Mixed-Radix"; std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: "; CHK_IF_RETURN(cudaEventCreate(&start)); @@ -78,6 +78,7 @@ int main(int argc, char** argv) // (3) NTT for A,B from cpu to gpu ntt_config.are_inputs_on_device = false; ntt_config.are_outputs_on_device = true; + ntt_config.ordering = ntt::Ordering::kNM; CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA)); CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB)); @@ -89,6 +90,7 @@ int main(int argc, char** argv) // (5) INTT (in place) ntt_config.are_inputs_on_device = true; ntt_config.are_outputs_on_device = true; + ntt_config.ordering = ntt::Ordering::kMN; CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu)); CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream)); diff --git a/icicle/appUtils/ntt/kernel_ntt.cu b/icicle/appUtils/ntt/kernel_ntt.cu index 2a4985e4..01c8623d 100644 --- a/icicle/appUtils/ntt/kernel_ntt.cu +++ b/icicle/appUtils/ntt/kernel_ntt.cu @@ -6,7 +6,7 @@ namespace ntt { - static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit) + static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit) { uint32_t rev_num = 0, temp, dig_len; if (dit) { @@ -29,10 +29,34 @@ namespace ntt { return rev_num; } + static inline __device__ uint32_t bit_rev(uint32_t num, uint32_t log_size) { return __brev(num) >> (32 - log_size); } + + enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural }; + + static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, eRevType rev_type) + { + switch (rev_type) { + case eRevType::RevToMixedRev: + // R -> N -> MR + return dig_rev(bit_rev(num, log_size), log_size, dit); + case eRevType::MixedRevToRev: + // MR -> N -> R + return bit_rev(dig_rev(num, log_size, dit), log_size); + case eRevType::NaturalToMixedRev: + case eRevType::MixedRevToNatural: + return dig_rev(num, log_size, dit); + case eRevType::NaturalToRev: + return bit_rev(num, log_size); + default: + return num; + } + return num; + } + // Note: the following reorder kernels are fused with normalization for INTT template - static __global__ void - reorder_digits_inplace_and_normalize_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N) + static __global__ void reorder_digits_inplace_and_normalize_kernel( + E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N) { // launch N threads (per batch element) // each thread starts from one index and calculates the corresponding group @@ -50,7 +74,7 @@ namespace ntt { uint32_t i = 1; for 
(; i < MAX_GROUP_SIZE;) { - next_element = dig_rev(next_element, log_size, dit); + next_element = generalized_rev(next_element, log_size, dit, rev_type); if (next_element < idx) return; // not handling this group if (next_element == idx) break; // calculated whole group group[i++] = next_element + size * batch_idx; @@ -67,16 +91,17 @@ namespace ntt { template __launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel( - E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N) + E* arr, E* arr_reordered, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N) { uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; uint32_t rd = tid; - uint32_t wr = ((tid >> log_size) << log_size) + dig_rev(tid & ((1 << log_size) - 1), log_size, dit); + uint32_t wr = + ((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type); arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd]; } template - static __global__ void BatchMulKernelDigReverse( + static __global__ void batch_elementwise_mul_with_reorder( E* in_vec, int n_elements, int batch_size, @@ -84,14 +109,14 @@ namespace ntt { int step, int n_scalars, int logn, - bool digit_rev, + eRevType rev_type, bool dit, E* out_vec) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid >= n_elements * batch_size) return; int64_t scalar_id = tid % n_elements; - if (digit_rev) scalar_id = dig_rev(tid, logn, dit); + if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, rev_type); out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid]; } @@ -630,56 +655,80 @@ namespace ntt { { CHK_INIT_IF_RETURN(); - // TODO: can we support all orderings? Note that reversal is generally digit reverse (generalization of bit reverse) - if (ordering != Ordering::kNN) { - throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only"); - } - const int logn = int(log2(ntt_size)); - const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64; const int NOF_THREADS = min(64, (1 << logn) * batch_size); - // Note: dif is slightly faster than dit but since reordering is a post-process stage, it must be computed in-place - // which make it slower e2e in most cases. dit reorders as a pre-process stage and therefore supports both in-place - // and out-of-place (when in!=out); - const bool reverse_input = ordering == Ordering::kNN; - const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN; bool is_normalize = is_inverse; const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset; const int n_twiddles = 1 << max_logn; + // Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication + eRevType reverse_input = None, reverse_output = None, reverse_coset = None; + bool dit = false; + switch (ordering) { + case Ordering::kNN: + reverse_input = eRevType::NaturalToMixedRev; + dit = true; + break; + case Ordering::kRN: + reverse_input = eRevType::RevToMixedRev; + dit = true; + reverse_coset = is_inverse ? eRevType::None : eRevType::NaturalToRev; + break; + case Ordering::kNR: + reverse_output = eRevType::MixedRevToRev; + reverse_coset = is_inverse ? 
eRevType::NaturalToRev : eRevType::None; + break; + case Ordering::kRR: + reverse_input = eRevType::RevToMixedRev; + dit = true; + reverse_output = eRevType::NaturalToRev; + reverse_coset = eRevType::NaturalToRev; + break; + case Ordering::kMN: + dit = true; + reverse_coset = is_inverse ? None : eRevType::NaturalToMixedRev; + break; + case Ordering::kNM: + reverse_coset = is_inverse ? eRevType::NaturalToMixedRev : eRevType::None; + break; + } - // TODO: fuse BatchMulKernelDigReverse with input reorder (and normalize)? if (is_on_coset && !is_inverse) { - BatchMulKernelDigReverse<<>>( + batch_elementwise_mul_with_reorder<<>>( d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles, - arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output); + arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output); d_input = d_output; } - if (reverse_input) { - // Note: fused reorder and normalize (for INTT) + if (reverse_input != eRevType::None) { const bool is_reverse_in_place = (d_input == d_output); if (is_reverse_in_place) { reorder_digits_inplace_and_normalize_kernel<<>>( - d_output, logn, is_dit, is_normalize, S::inv_log_size(logn)); + d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn)); } else { reorder_digits_and_normalize_kernel<<>>( - d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn)); + d_input, d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn)); } is_normalize = false; + d_input = d_output; } // inplace ntt CHK_IF_RETURN(large_ntt( - d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse, - is_normalize, is_dit, cuda_stream)); + d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse, + (is_normalize && reverse_output == eRevType::None), dit, cuda_stream)); + + if (reverse_output != eRevType::None) { + reorder_digits_inplace_and_normalize_kernel<<>>( + d_output, logn, dit, reverse_output, is_normalize, S::inv_log_size(logn)); + } if (is_on_coset && is_inverse) { - BatchMulKernelDigReverse<<>>( + batch_elementwise_mul_with_reorder<<>>( d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, - arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output); + arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output); } return CHK_LAST(); diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu index 755165b5..a31ada0f 100644 --- a/icicle/appUtils/ntt/ntt.cu +++ b/icicle/appUtils/ntt/ntt.cu @@ -278,7 +278,7 @@ namespace ntt { int batch_size, int logn, bool inverse, - bool ct_buttterfly, + bool dit, S* arbitrary_coset, int coset_gen_index, cudaStream_t stream, @@ -309,9 +309,9 @@ namespace ntt { if (direct_coset) utils_internal::BatchMulKernel<<>>( d_input, n, batch_size, arbitrary_coset ? arbitrary_coset : d_twiddles, arbitrary_coset ? 1 : coset_gen_index, - n_twiddles, logn, ct_buttterfly, d_output); + n_twiddles, logn, dit, d_output); - if (ct_buttterfly) { + if (dit) { if (is_shared_mem_enabled) ntt_template_kernel_shared<<>>( direct_coset ? d_output : d_input, 1 << logn_shmem, d_twiddles, n_twiddles, total_tasks, 0, logn_shmem, @@ -340,7 +340,7 @@ namespace ntt { if (is_on_coset) utils_internal::BatchMulKernel<<>>( d_output, n, batch_size, arbitrary_coset ? arbitrary_coset : d_twiddles, - arbitrary_coset ? 
1 : -coset_gen_index, -n_twiddles, logn, !ct_buttterfly, d_output); + arbitrary_coset ? 1 : -coset_gen_index, -n_twiddles, logn, !dit, d_output); utils_internal::NormalizeKernel <<>>(d_output, S::inv_log_size(logn), n * batch_size); @@ -452,6 +452,76 @@ namespace ntt { return CHK_LAST(); } + template + static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig& config) + { + const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7); + const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2; + const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg; + if (is_force_radix2) return true; + + const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix; + if (is_user_selected_mixed_radix_alg) return false; + + // Heuristic to automatically select an algorithm + // Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and + // the specific GPU. + // the following heuristic is a simplification based on measurements. Users can try both and select the algorithm + // based on the specific case via the 'NTTConfig.ntt_algorithm' field + + if (logn >= 16) return false; // mixed-radix is typically faster in those cases + if (logn <= 11) return true; // radix-2 is typically faster for batch<=256 in those cases + const int log_batch = (int)log2(batch_size); + return (logn + log_batch <= 18); // almost the cutoff point where both are equal + } + + template + cudaError_t radix2_ntt( + E* d_input, + E* d_output, + S* twiddles, + int ntt_size, + int max_size, + int batch_size, + bool is_inverse, + Ordering ordering, + S* arbitrary_coset, + int coset_gen_index, + cudaStream_t cuda_stream) + { + CHK_INIT_IF_RETURN(); + + const int logn = int(log2(ntt_size)); + + bool dit = true; + bool reverse_input = false; + switch (ordering) { + case Ordering::kNN: + reverse_input = true; + break; + case Ordering::kNR: + case Ordering::kNM: + dit = false; + break; + case Ordering::kRR: + reverse_input = true; + dit = false; + break; + case Ordering::kRN: + case Ordering::kMN: + dit = true; + reverse_input = false; + } + + if (reverse_input) reverse_order_batch(d_input, ntt_size, logn, batch_size, cuda_stream, d_output); + + CHK_IF_RETURN(ntt_inplace_batch_template( + reverse_input ? 
d_output : d_input, ntt_size, twiddles, max_size, batch_size, logn, is_inverse, dit, + arbitrary_coset, coset_gen_index, cuda_stream, d_output)); + + return CHK_LAST(); + } + template cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig& config, E* output) { @@ -501,37 +571,17 @@ namespace ntt { h_coset.clear(); } - // (heuristic) cutoff point where mixed-radix is faster than radix-2 - const bool is_small_ntt = (logn < 16) && ((size_t)size * batch_size < (1 << 20)); - const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation - const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || !is_NN; + const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config); + const bool is_inverse = dir == NTTDir::kInverse; if (is_radix2_algorithm) { - bool ct_butterfly = true; - bool reverse_input = false; - switch (config.ordering) { - case Ordering::kNN: - reverse_input = true; - break; - case Ordering::kNR: - ct_butterfly = false; - break; - case Ordering::kRR: - reverse_input = true; - ct_butterfly = false; - break; - } - - if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output); - - CHK_IF_RETURN(ntt_inplace_batch_template( - reverse_input ? d_output : d_input, size, Domain::twiddles, Domain::max_size, batch_size, logn, - dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output)); - - } else { // mixed-radix algorithm + CHK_IF_RETURN(ntt::radix2_ntt( + d_input, d_output, Domain::twiddles, size, Domain::max_size, batch_size, is_inverse, config.ordering, + coset, coset_index, stream)); + } else { CHK_IF_RETURN(ntt::mixed_radix_ntt( d_input, d_output, Domain::twiddles, Domain::internal_twiddles, Domain::basic_twiddles, size, - Domain::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, coset, coset_index, stream)); + Domain::max_log_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream)); } if (!are_outputs_on_device) @@ -550,14 +600,14 @@ namespace ntt { { device_context::DeviceContext ctx = device_context::get_default_device_context(); NTTConfig config = { - ctx, // ctx - S::one(), // coset_gen - 1, // batch_size - Ordering::kNN, // ordering - false, // are_inputs_on_device - false, // are_outputs_on_device - false, // is_async - false, // is_force_radix2 + ctx, // ctx + S::one(), // coset_gen + 1, // batch_size + Ordering::kNN, // ordering + false, // are_inputs_on_device + false, // are_outputs_on_device + false, // is_async + NttAlgorithm::Auto, // ntt_algorithm }; return config; } diff --git a/icicle/appUtils/ntt/ntt.cuh b/icicle/appUtils/ntt/ntt.cuh index e82de4ef..6285a3c9 100644 --- a/icicle/appUtils/ntt/ntt.cuh +++ b/icicle/appUtils/ntt/ntt.cuh @@ -59,8 +59,28 @@ namespace ntt { * a_4, a_2, a_6, a_1, a_5, a_3, a_7\} \f$). * - kRN: inputs are bit-reversed-order and outputs are natural-order. * - kRR: inputs and outputs are bit-reversed-order. + * + * Mixed-Radix NTT: digit-reversal is a generalization of bit-reversal where the latter is a special case with 1b + * digits. Mixed-radix NTTs of different sizes would generate different reordering of inputs/outputs. Having said + * that, for a given size N it is guaranteed that every two mixed-radix NTTs of size N would have the same + * digit-reversal pattern. The following orderings kNM and kMN are conceptually like kNR and kRN but for + * mixed-digit-reordering. Note that for the cases '(1) NTT, (2) elementwise ops and (3) INTT' kNM and kMN are most + * efficient. 
+ * Note: kNR, kRN, kRR refer to the radix-2 NTT reversal pattern. Those cases are supported by mixed-radix NTT with + * reduced efficiency compared to kNM and kMN. + * - kNM: inputs are natural-order and outputs are digit-reversed-order (=mixed). + * - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order. */ - enum class Ordering { kNN, kNR, kRN, kRR }; + enum class Ordering { kNN, kNR, kRN, kRR, kNM, kMN }; + + /** + * @enum NttAlgorithm + * Which NTT algorithm to use. options are: + * - Auto: implementation selects automatically based on heuristic. This value is a good default for most cases. + * - Radix2: explicitly select radix-2 NTT algorithm + * - MixedRadix: explicitly select mixed-radix NTT algorithm + */ + enum class NttAlgorithm { Auto, Radix2, MixedRadix }; /** * @struct NTTConfig @@ -80,8 +100,8 @@ namespace ntt { * non-blocking and you'd need to synchronize it explicitly by running * `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT * function will block the current CPU thread. */ - bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects - radix-2 or mixed-radix algorithm based on heuristics). */ + NttAlgorithm ntt_algorithm; /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation + selects radix-2 or mixed-radix algorithm based on heuristics). */ }; /** diff --git a/icicle/appUtils/ntt/tests/verification.cu b/icicle/appUtils/ntt/tests/verification.cu index 141ea1ca..36f97fa4 100644 --- a/icicle/appUtils/ntt/tests/verification.cu +++ b/icicle/appUtils/ntt/tests/verification.cu @@ -35,22 +35,25 @@ int main(int argc, char** argv) cudaEvent_t icicle_start, icicle_stop, new_start, new_stop; float icicle_time, new_time; - int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 16; // assuming second input is the log-size + int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming second input is the log-size int NTT_SIZE = 1 << NTT_LOG_SIZE; bool INPLACE = (argc > 2) ? atoi(argv[2]) : false; int INV = (argc > 3) ? atoi(argv[3]) : false; - int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 32; - int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 1; + int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1; + int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0; + const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN; - const ntt::Ordering ordering = ntt::Ordering::kNN; + // Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN" : ordering == ntt::Ordering::kNR ? "NR" : ordering == ntt::Ordering::kRN ? "RN" - : "RR"; + : ordering == ntt::Ordering::kRR ? "RR" + : ordering == ntt::Ordering::kNM ? 
"NM" + : "MN"; printf( - "running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d\n", NTT_LOG_SIZE, ordering_str, - INPLACE, INV, BATCH_SIZE, COSET_IDX); + "running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s\n", NTT_LOG_SIZE, INPLACE, INV, + BATCH_SIZE, COSET_IDX, ordering_str); CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup) @@ -103,7 +106,7 @@ int main(int argc, char** argv) auto benchmark = [&](bool is_print, int iterations) -> cudaError_t { // NEW CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream)); - ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt) + ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix; for (size_t i = 0; i < iterations; i++) { ntt::NTT( INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, @@ -116,7 +119,7 @@ int main(int argc, char** argv) // OLD CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream)); - ntt_config.is_force_radix2 = true; + ntt_config.ntt_algorithm = ntt::NttAlgorithm::Radix2; for (size_t i = 0; i < iterations; i++) { ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld); } diff --git a/wrappers/rust/icicle-core/src/ntt/mod.rs b/wrappers/rust/icicle-core/src/ntt/mod.rs index 24591413..6ef6a5b9 100644 --- a/wrappers/rust/icicle-core/src/ntt/mod.rs +++ b/wrappers/rust/icicle-core/src/ntt/mod.rs @@ -30,6 +30,17 @@ pub enum NTTDir { /// a_4, a_2, a_6, a_1, a_5, a_3, a_7`. /// - kRN: inputs are bit-reversed-order and outputs are natural-order. /// - kRR: inputs and outputs are bit-reversed-order. +/// +/// Mixed-Radix NTT: digit-reversal is a generalization of bit-reversal where the latter is a special case with 1b +/// digits. Mixed-radix NTTs of different sizes would generate different reordering of inputs/outputs. Having said +/// that, for a given size N it is guaranteed that every two mixed-radix NTTs of size N would have the same +/// digit-reversal pattern. The following orderings kNM and kMN are conceptually like kNR and kRN but for +/// mixed-digit-reordering. Note that for the cases '(1) NTT, (2) elementwise ops and (3) INTT' kNM and kMN are most +/// efficient. +/// Note: kNR, kRN, kRR refer to the radix-2 NTT reversal pattern. Those cases are supported by mixed-radix NTT with +/// reduced efficiency compared to kNM and kMN. +/// - kNM: inputs are natural-order and outputs are digit-reversed-order (=mixed). +/// - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order. #[allow(non_camel_case_types)] #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -38,6 +49,22 @@ pub enum Ordering { kNR, kRN, kRR, + kNM, + kMN, +} + +///Which NTT algorithm to use. options are: +///- Auto: implementation selects automatically based on heuristic. This value is a good default for most cases. +///- Radix2: explicitly select radix-2 NTT algorithm +///- MixedRadix: explicitly select mixed-radix NTT algorithm +/// +#[allow(non_camel_case_types)] +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NttAlgorithm { + Auto, + Radix2, + MixedRadix, } /// Struct that encodes NTT parameters to be passed into the [ntt](ntt) function. @@ -57,8 +84,9 @@ pub struct NTTConfig<'a, S> { /// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize /// it explicitly by running `stream.synchronize()`. 
If set to false, the NTT function will block the current CPU thread. pub is_async: bool, - /// Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects radix-2 or mixed-radix algorithm based on heuristics). - pub is_force_radix2: bool, + /// Explicitly select the NTT algorithm. Default value: Auto (the implementation selects radix-2 or mixed-radix algorithm based + /// on heuristics + pub ntt_algorithm: NttAlgorithm, } impl<'a, S: FieldImpl> NTTConfig<'a, S> { @@ -72,7 +100,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> { are_inputs_on_device: false, are_outputs_on_device: false, is_async: false, - is_force_radix2: false, + ntt_algorithm: NttAlgorithm::Auto, } } } diff --git a/wrappers/rust/icicle-core/src/ntt/tests.rs b/wrappers/rust/icicle-core/src/ntt/tests.rs index 3138a230..3819441b 100644 --- a/wrappers/rust/icicle-core/src/ntt/tests.rs +++ b/wrappers/rust/icicle-core/src/ntt/tests.rs @@ -6,7 +6,7 @@ use icicle_cuda_runtime::memory::HostOrDeviceSlice; use icicle_cuda_runtime::stream::CudaStream; use crate::{ - ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, Ordering}, + ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering}, traits::{ArkConvertible, FieldImpl, GenerateRandom}, }; @@ -61,23 +61,26 @@ where let scalars_mont = unsafe { &*(&ark_scalars[..] as *const _ as *const [F]) }; let scalars_mont_h = HostOrDeviceSlice::on_host(scalars_mont.to_vec()); - let config = get_default_ntt_config(); - let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); - ntt(&scalars_mont_h, NTTDir::kForward, &config, &mut ntt_result).unwrap(); - assert_ne!(ntt_result.as_slice(), scalars_mont); + let mut config = get_default_ntt_config(); + for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] { + config.ntt_algorithm = alg; + let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); + ntt(&scalars_mont_h, NTTDir::kForward, &config, &mut ntt_result).unwrap(); + assert_ne!(ntt_result.as_slice(), scalars_mont); - let mut ark_ntt_result = ark_scalars.clone(); - ark_domain.fft_in_place(&mut ark_ntt_result); - assert_ne!(ark_ntt_result, ark_scalars); + let mut ark_ntt_result = ark_scalars.clone(); + ark_domain.fft_in_place(&mut ark_ntt_result); + assert_ne!(ark_ntt_result, ark_scalars); - let ntt_result_as_ark = - unsafe { &*(ntt_result.as_slice() as *const _ as *const [::ArkEquivalent]) }; - assert_eq!(ark_ntt_result, ntt_result_as_ark); + let ntt_result_as_ark = + unsafe { &*(ntt_result.as_slice() as *const _ as *const [::ArkEquivalent]) }; + assert_eq!(ark_ntt_result, ntt_result_as_ark); - let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); - ntt(&ntt_result, NTTDir::kInverse, &config, &mut intt_result).unwrap(); + let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); + ntt(&ntt_result, NTTDir::kInverse, &config, &mut intt_result).unwrap(); - assert_eq!(intt_result.as_slice(), scalars_mont); + assert_eq!(intt_result.as_slice(), scalars_mont); + } } } @@ -103,55 +106,58 @@ where .map(|v| v.to_ark()) .collect::>(); - let mut config = get_default_ntt_config(); - config.ordering = Ordering::kNR; - let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]); - let mut ntt_result_2 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]); - ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_1).unwrap(); - assert_ne!(*ntt_result_1.as_slice(), scalars); - config.coset_gen = F::from_ark(test_size_rou); - 
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_2).unwrap(); - let mut ntt_large_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); - // back to non-coset NTT - config.coset_gen = F::one(); - scalars.resize(test_size, F::zero()); - ntt( - &HostOrDeviceSlice::on_host(scalars.clone()), - NTTDir::kForward, - &config, - &mut ntt_large_result, - ) - .unwrap(); - assert_eq!(*ntt_result_1.as_slice(), ntt_large_result.as_slice()[..small_size]); - assert_eq!(*ntt_result_2.as_slice(), ntt_large_result.as_slice()[small_size..]); + for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] { + let mut config = get_default_ntt_config(); + config.ordering = Ordering::kNR; + config.ntt_algorithm = alg; + let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]); + let mut ntt_result_2 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]); + ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_1).unwrap(); + assert_ne!(*ntt_result_1.as_slice(), scalars); + config.coset_gen = F::from_ark(test_size_rou); + ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_2).unwrap(); + let mut ntt_large_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); + // back to non-coset NTT + config.coset_gen = F::one(); + scalars.resize(test_size, F::zero()); + ntt( + &HostOrDeviceSlice::on_host(scalars.clone()), + NTTDir::kForward, + &config, + &mut ntt_large_result, + ) + .unwrap(); + assert_eq!(*ntt_result_1.as_slice(), ntt_large_result.as_slice()[..small_size]); + assert_eq!(*ntt_result_2.as_slice(), ntt_large_result.as_slice()[small_size..]); - let mut ark_large_scalars = ark_scalars.clone(); - ark_small_domain.fft_in_place(&mut ark_scalars); - let ntt_result_as_ark = ntt_large_result - .as_slice() - .iter() - .map(|p| p.to_ark()) - .collect::>(); - assert_eq!( - ark_scalars[..small_size], - list_to_reverse_bit_order(&ntt_result_as_ark[small_size..]) - ); - ark_large_domain.fft_in_place(&mut ark_large_scalars); - assert_eq!(ark_large_scalars, list_to_reverse_bit_order(&ntt_result_as_ark)); + let mut ark_large_scalars = ark_scalars.clone(); + ark_small_domain.fft_in_place(&mut ark_scalars); + let ntt_result_as_ark = ntt_large_result + .as_slice() + .iter() + .map(|p| p.to_ark()) + .collect::>(); + assert_eq!( + ark_scalars[..small_size], + list_to_reverse_bit_order(&ntt_result_as_ark[small_size..]) + ); + ark_large_domain.fft_in_place(&mut ark_large_scalars); + assert_eq!(ark_large_scalars, list_to_reverse_bit_order(&ntt_result_as_ark)); - config.coset_gen = F::from_ark(test_size_rou); - config.ordering = Ordering::kRN; - let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]); - ntt(&ntt_result_2, NTTDir::kInverse, &config, &mut intt_result).unwrap(); - assert_eq!(*intt_result.as_slice(), scalars[..small_size]); + config.coset_gen = F::from_ark(test_size_rou); + config.ordering = Ordering::kRN; + let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]); + ntt(&ntt_result_2, NTTDir::kInverse, &config, &mut intt_result).unwrap(); + assert_eq!(*intt_result.as_slice(), scalars[..small_size]); - ark_small_domain.ifft_in_place(&mut ark_scalars); - let intt_result_as_ark = intt_result - .as_slice() - .iter() - .map(|p| p.to_ark()) - .collect::>(); - assert_eq!(ark_scalars[..small_size], intt_result_as_ark); + ark_small_domain.ifft_in_place(&mut ark_scalars); + let intt_result_as_ark = intt_result + .as_slice() + .iter() + .map(|p| p.to_ark()) + .collect::>(); + assert_eq!(ark_scalars[..small_size], 
intt_result_as_ark); + } } } @@ -183,31 +189,34 @@ where .collect::>(); let mut config = get_default_ntt_config(); - config.ordering = Ordering::kNR; config.coset_gen = F::from_ark(coset_gen); - let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); - ntt(&scalars, NTTDir::kForward, &config, &mut ntt_result).unwrap(); - assert_ne!(scalars.as_slice(), ntt_result.as_slice()); + for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] { + config.ordering = Ordering::kNR; + config.ntt_algorithm = alg; + let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]); + ntt(&scalars, NTTDir::kForward, &config, &mut ntt_result).unwrap(); + assert_ne!(scalars.as_slice(), ntt_result.as_slice()); - let ark_scalars_copy = ark_scalars.clone(); - ark_domain.fft_in_place(&mut ark_scalars); - let ntt_result_as_ark = ntt_result - .as_slice() - .iter() - .map(|p| p.to_ark()) - .collect::>(); - assert_eq!(ark_scalars, list_to_reverse_bit_order(&ntt_result_as_ark)); - ark_domain.ifft_in_place(&mut ark_scalars); - assert_eq!(ark_scalars, ark_scalars_copy); + let ark_scalars_copy = ark_scalars.clone(); + ark_domain.fft_in_place(&mut ark_scalars); + let ntt_result_as_ark = ntt_result + .as_slice() + .iter() + .map(|p| p.to_ark()) + .collect::>(); + assert_eq!(ark_scalars, list_to_reverse_bit_order(&ntt_result_as_ark)); + ark_domain.ifft_in_place(&mut ark_scalars); + assert_eq!(ark_scalars, ark_scalars_copy); - config.ordering = Ordering::kRN; - ntt(&ntt_result, NTTDir::kInverse, &config, &mut scalars).unwrap(); - let ntt_result_as_ark = scalars - .as_slice() - .iter() - .map(|p| p.to_ark()) - .collect::>(); - assert_eq!(ark_scalars, ntt_result_as_ark); + config.ordering = Ordering::kRN; + ntt(&ntt_result, NTTDir::kInverse, &config, &mut scalars).unwrap(); + let ntt_result_as_ark = scalars + .as_slice() + .iter() + .map(|p| p.to_ark()) + .collect::>(); + assert_eq!(ark_scalars, ntt_result_as_ark); + } } } } @@ -216,7 +225,7 @@ pub fn check_ntt_batch() where ::Config: NTT + GenerateRandom, { - let test_sizes = [1 << 4, 1 << 14]; + let test_sizes = [1 << 4, 1 << 12]; let batch_sizes = [1, 1 << 4, 100]; for test_size in test_sizes { let coset_generators = [F::one(), F::Config::generate_random(1)[0]]; @@ -226,26 +235,37 @@ where for coset_gen in coset_generators { for is_inverse in [NTTDir::kInverse, NTTDir::kForward] { - for ordering in [Ordering::kNN, Ordering::kNR, Ordering::kRN, Ordering::kRR] { + for ordering in [ + Ordering::kNN, + Ordering::kNR, + Ordering::kRN, + Ordering::kRR, + Ordering::kNM, + Ordering::kMN, + ] { config.coset_gen = coset_gen; config.ordering = ordering; - config.batch_size = batch_size as i32; - let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]); - ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap(); - config.batch_size = 1; - let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]); - for i in 0..batch_size { - ntt( - &HostOrDeviceSlice::on_host(scalars[i * test_size..(i + 1) * test_size].to_vec()), - is_inverse, - &config, - &mut one_ntt_result, - ) - .unwrap(); - assert_eq!( - batch_ntt_result[i * test_size..(i + 1) * test_size], - *one_ntt_result.as_slice() - ); + for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] { + config.batch_size = batch_size as i32; + config.ntt_algorithm = alg; + let mut batch_ntt_result = + HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]); + ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap(); + 
config.batch_size = 1; + let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]); + for i in 0..batch_size { + ntt( + &HostOrDeviceSlice::on_host(scalars[i * test_size..(i + 1) * test_size].to_vec()), + is_inverse, + &config, + &mut one_ntt_result, + ) + .unwrap(); + assert_eq!( + batch_ntt_result[i * test_size..(i + 1) * test_size], + *one_ntt_result.as_slice() + ); + } } } } @@ -286,22 +306,25 @@ where config .ctx .stream = &stream; - ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap(); - ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap(); - let mut intt_result_h = vec![F::zero(); test_size * batch_size]; - scalars_d - .copy_to_host_async(&mut intt_result_h, &stream) - .unwrap(); - stream - .synchronize() - .unwrap(); - assert_eq!(scalars_h, intt_result_h); - if coset_gen == F::one() { - let mut ntt_result_h = vec![F::zero(); test_size * batch_size]; - ntt_out_d - .copy_to_host(&mut ntt_result_h) + for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] { + config.ntt_algorithm = alg; + ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap(); + ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap(); + let mut intt_result_h = vec![F::zero(); test_size * batch_size]; + scalars_d + .copy_to_host_async(&mut intt_result_h, &stream) .unwrap(); - assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark()); + stream + .synchronize() + .unwrap(); + assert_eq!(scalars_h, intt_result_h); + if coset_gen == F::one() { + let mut ntt_result_h = vec![F::zero(); test_size * batch_size]; + ntt_out_d + .copy_to_host(&mut ntt_result_h) + .unwrap(); + assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark()); + } } } }
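
Usage sketch (not part of the patch): the flow this change enables is a forward NTT with Ordering::kNM, element-wise work on the device buffer, then an inverse NTT with Ordering::kMN, so the mixed-radix algorithm can skip the extra digit-reversal passes that kNN/kNR/kRN require. The snippet below mirrors the updated polynomial_multiplication example and uses only the API surface visible in this diff; the scalar type `test_scalar` and the buffer names `host_coeffs` / `gpu_evals` are illustrative placeholders (assumed to be defined and allocated elsewhere, as in example.cu), not definitions introduced by this patch.

    // Assumed context: inside a function that returns cudaError_t (so CHK_IF_RETURN applies),
    // with host_coeffs (host) and gpu_evals (device) buffers of ntt_size elements of test_scalar.
    auto cfg = ntt::DefaultNTTConfig<test_scalar>();
    cfg.ntt_algorithm = ntt::NttAlgorithm::MixedRadix; // or NttAlgorithm::Auto to let the heuristic decide

    // (1) forward NTT: natural-order host input -> digit-reversed (mixed) order output on device
    cfg.are_inputs_on_device = false;
    cfg.are_outputs_on_device = true;
    cfg.ordering = ntt::Ordering::kNM;
    CHK_IF_RETURN(ntt::NTT(host_coeffs, ntt_size, ntt::NTTDir::kForward, cfg, gpu_evals));

    // (2) element-wise operations on gpu_evals (order-independent, e.g. pointwise multiplication)

    // (3) inverse NTT in place: digit-reversed (mixed) order input -> natural-order output
    cfg.are_inputs_on_device = true;
    cfg.ordering = ntt::Ordering::kMN;
    CHK_IF_RETURN(ntt::NTT(gpu_evals, ntt_size, ntt::NTTDir::kInverse, cfg, gpu_evals));

The ordering of the intermediate data is irrelevant for element-wise operations, which is why kNM/kMN is the cheapest pairing for the NTT -> elementwise -> INTT pattern described in the ntt.cuh and mod.rs doc comments.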