Mixed-radix NTT support all orderings (#371)

- Mixed-radix NTT: support for all orderings
- Small radix-2 refactor: split core logic into a function and renamed ct_butterfly to dit
- Test both the radix-2 and mixed-radix algorithms in all NTT tests
yshekel
2024-02-13 15:49:24 +02:00
committed by GitHub
parent ae060313db
commit a02459c64d
7 changed files with 379 additions and 204 deletions

View File

@@ -46,10 +46,10 @@ int main(int argc, char** argv)
// init domain
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward
ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false;
const bool is_radix2_alg = (argc > 1) ? atoi(argv[1]) : false;
ntt_config.ntt_algorithm = is_radix2_alg ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix;
const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix";
const char* ntt_alg_str = is_radix2_alg ? "Radix-2" : "Mixed-Radix";
std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: ";
CHK_IF_RETURN(cudaEventCreate(&start));
@@ -78,6 +78,7 @@ int main(int argc, char** argv)
// (3) NTT for A,B from cpu to gpu
ntt_config.are_inputs_on_device = false;
ntt_config.are_outputs_on_device = true;
ntt_config.ordering = ntt::Ordering::kNM;
CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));
@@ -89,6 +90,7 @@ int main(int argc, char** argv)
// (5) INTT (in place)
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
ntt_config.ordering = ntt::Ordering::kMN;
CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));
CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream));

View File

@@ -6,7 +6,7 @@
namespace ntt {
static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
{
uint32_t rev_num = 0, temp, dig_len;
if (dit) {
@@ -29,10 +29,34 @@ namespace ntt {
return rev_num;
}
static inline __device__ uint32_t bit_rev(uint32_t num, uint32_t log_size) { return __brev(num) >> (32 - log_size); }
enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural };
static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, eRevType rev_type)
{
switch (rev_type) {
case eRevType::RevToMixedRev:
// R -> N -> MR
return dig_rev(bit_rev(num, log_size), log_size, dit);
case eRevType::MixedRevToRev:
// MR -> N -> R
return bit_rev(dig_rev(num, log_size, dit), log_size);
case eRevType::NaturalToMixedRev:
case eRevType::MixedRevToNatural:
return dig_rev(num, log_size, dit);
case eRevType::NaturalToRev:
return bit_rev(num, log_size);
default:
return num;
}
return num;
}
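For illustration (not part of this diff): bit reversal is the special case of digit reversal with 1-bit digits. The minimal host-side sketch below shows a generic digit reversal over an assumed digit split of {3, 3, 2} bits for log_size = 8; the actual digit widths used by dig_rev are an internal detail of the mixed-radix kernels and also depend on the dit flag.
#include <cstdint>
#include <cstdio>
#include <vector>
// Reverse the order of the digits of `num`, where digit i is `digit_bits[i]` bits wide
// (lowest digit first). With all widths equal to 1 this degenerates to plain bit reversal.
static uint32_t digit_reverse(uint32_t num, const std::vector<uint32_t>& digit_bits)
{
  uint32_t rev = 0;
  for (uint32_t bits : digit_bits) {
    rev = (rev << bits) | (num & ((1u << bits) - 1)); // peel the lowest digit, append it to rev
    num >>= bits;
  }
  return rev;
}
int main()
{
  const std::vector<uint32_t> mixed = {3, 3, 2};  // assumed mixed-radix digit widths (sum = 8 bits)
  const std::vector<uint32_t> radix2(8, 1);       // eight 1-bit digits == bit reversal
  for (uint32_t i = 0; i < 8; ++i)
    printf("%u: bit_rev=%u dig_rev=%u\n", i, digit_reverse(i, radix2), digit_reverse(i, mixed));
  return 0;
}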
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void
reorder_digits_inplace_and_normalize_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
static __global__ void reorder_digits_inplace_and_normalize_kernel(
E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -50,7 +74,7 @@ namespace ntt {
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = dig_rev(next_element, log_size, dit);
next_element = generalized_rev(next_element, log_size, dit, rev_type);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element + size * batch_idx;
@@ -67,16 +91,17 @@ namespace ntt {
template <typename E, typename S>
__launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
E* arr, E* arr_reordered, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t rd = tid;
uint32_t wr = ((tid >> log_size) << log_size) + dig_rev(tid & ((1 << log_size) - 1), log_size, dit);
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
template <typename E, typename S>
static __global__ void BatchMulKernelDigReverse(
static __global__ void batch_elementwise_mul_with_reorder(
E* in_vec,
int n_elements,
int batch_size,
@@ -84,14 +109,14 @@ namespace ntt {
int step,
int n_scalars,
int logn,
bool digit_rev,
eRevType rev_type,
bool dit,
E* out_vec)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= n_elements * batch_size) return;
int64_t scalar_id = tid % n_elements;
if (digit_rev) scalar_id = dig_rev(tid, logn, dit);
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, rev_type);
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
@@ -630,56 +655,80 @@ namespace ntt {
{
CHK_INIT_IF_RETURN();
// TODO: can we support all orderings? Note that reversal is generally digit reverse (generalization of bit reverse)
if (ordering != Ordering::kNN) {
throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only");
}
const int logn = int(log2(ntt_size));
const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
const int NOF_THREADS = min(64, (1 << logn) * batch_size);
// Note: dif is slightly faster than dit, but since its reordering is a post-process stage it must be computed
// in-place, which makes it slower e2e in most cases. dit reorders as a pre-process stage and therefore supports both
// in-place and out-of-place (when in != out).
const bool reverse_input = ordering == Ordering::kNN;
const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
bool is_normalize = is_inverse;
const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
const int n_twiddles = 1 << max_logn;
// Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication
eRevType reverse_input = None, reverse_output = None, reverse_coset = None;
bool dit = false;
switch (ordering) {
case Ordering::kNN:
reverse_input = eRevType::NaturalToMixedRev;
dit = true;
break;
case Ordering::kRN:
reverse_input = eRevType::RevToMixedRev;
dit = true;
reverse_coset = is_inverse ? eRevType::None : eRevType::NaturalToRev;
break;
case Ordering::kNR:
reverse_output = eRevType::MixedRevToRev;
reverse_coset = is_inverse ? eRevType::NaturalToRev : eRevType::None;
break;
case Ordering::kRR:
reverse_input = eRevType::RevToMixedRev;
dit = true;
reverse_output = eRevType::NaturalToRev;
reverse_coset = eRevType::NaturalToRev;
break;
case Ordering::kMN:
dit = true;
reverse_coset = is_inverse ? None : eRevType::NaturalToMixedRev;
break;
case Ordering::kNM:
reverse_coset = is_inverse ? eRevType::NaturalToMixedRev : eRevType::None;
break;
}
// TODO: fuse BatchMulKernelDigReverse with input reorder (and normalize)?
if (is_on_coset && !is_inverse) {
BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
d_input = d_output;
}
if (reverse_input) {
// Note: fused reorder and normalize (for INTT)
if (reverse_input != eRevType::None) {
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
d_input, d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
d_input = d_output;
}
// inplace ntt
CHK_IF_RETURN(large_ntt(
d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
is_normalize, is_dit, cuda_stream));
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
(is_normalize && reverse_output == eRevType::None), dit, cuda_stream));
if (reverse_output != eRevType::None) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, reverse_output, is_normalize, S::inv_log_size(logn));
}
if (is_on_coset && is_inverse) {
BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
}
return CHK_LAST();
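As a condensed reference (not part of this diff), the table below summarizes how the switch above resolves each Ordering into a pre-NTT reorder, a post-NTT reorder, the butterfly type, and the per-direction coset reorder:
// Editorial summary of the Ordering -> (reverse_input, reverse_output, dit, reverse_coset) mapping above.
struct OrderingPlan {
  const char* ordering;
  const char* reverse_input;  // pre-NTT reorder
  const char* reverse_output; // post-NTT reorder
  bool dit;                   // true = DIT butterflies, false = DIF
  const char* coset_forward;  // coset reorder for forward NTT
  const char* coset_inverse;  // coset reorder for INTT
};
static const OrderingPlan ordering_plans[] = {
  {"kNN", "NaturalToMixedRev", "None",          true,  "None",              "None"},
  {"kRN", "RevToMixedRev",     "None",          true,  "NaturalToRev",      "None"},
  {"kNR", "None",              "MixedRevToRev", false, "None",              "NaturalToRev"},
  {"kRR", "RevToMixedRev",     "NaturalToRev",  true,  "NaturalToRev",      "NaturalToRev"},
  {"kMN", "None",              "None",          true,  "NaturalToMixedRev", "None"},
  {"kNM", "None",              "None",          false, "None",              "NaturalToMixedRev"},
};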

View File

@@ -278,7 +278,7 @@ namespace ntt {
int batch_size,
int logn,
bool inverse,
bool ct_buttterfly,
bool dit,
S* arbitrary_coset,
int coset_gen_index,
cudaStream_t stream,
@@ -309,9 +309,9 @@ namespace ntt {
if (direct_coset)
utils_internal::BatchMulKernel<E, S><<<num_blocks_coset, num_threads_coset, 0, stream>>>(
d_input, n, batch_size, arbitrary_coset ? arbitrary_coset : d_twiddles, arbitrary_coset ? 1 : coset_gen_index,
n_twiddles, logn, ct_buttterfly, d_output);
n_twiddles, logn, dit, d_output);
if (ct_buttterfly) {
if (dit) {
if (is_shared_mem_enabled)
ntt_template_kernel_shared<<<num_blocks, num_threads, shared_mem, stream>>>(
direct_coset ? d_output : d_input, 1 << logn_shmem, d_twiddles, n_twiddles, total_tasks, 0, logn_shmem,
@@ -340,7 +340,7 @@ namespace ntt {
if (is_on_coset)
utils_internal::BatchMulKernel<E, S><<<num_blocks_coset, num_threads_coset, 0, stream>>>(
d_output, n, batch_size, arbitrary_coset ? arbitrary_coset : d_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, -n_twiddles, logn, !ct_buttterfly, d_output);
arbitrary_coset ? 1 : -coset_gen_index, -n_twiddles, logn, !dit, d_output);
utils_internal::NormalizeKernel<E, S>
<<<num_blocks_coset, num_threads_coset, 0, stream>>>(d_output, S::inv_log_size(logn), n * batch_size);
@@ -452,6 +452,76 @@ namespace ntt {
return CHK_LAST();
}
template <typename S>
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
{
const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2;
const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg;
if (is_force_radix2) return true;
const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix;
if (is_user_selected_mixed_radix_alg) return false;
// Heuristic to automatically select an algorithm.
// Note that in general the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and
// the specific GPU.
// The following heuristic is a simplification based on measurements. Users can try both and select the algorithm
// for their specific case via the 'NTTConfig.ntt_algorithm' field.
if (logn >= 16) return false; // mixed-radix is typically faster in those cases
if (logn <= 11) return true; // radix-2 is typically faster for batch<=256 in those cases
const int log_batch = (int)log2(batch_size);
return (logn + log_batch <= 18); // almost the cutoff point where both are equal
}
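To make the cutoff concrete, the following standalone sketch (not part of this diff) reproduces the Auto-path thresholds above and evaluates them for a few shapes; the real function additionally honors an explicit NttAlgorithm selection and falls back to radix-2 where mixed-radix is unsupported (logn <= 3 or logn == 7):
#include <cmath>
#include <cstdio>
// Reproduces only the Auto heuristic thresholds from is_choose_radix2_algorithm.
static bool auto_picks_radix2(int logn, int batch_size)
{
  if (logn >= 16) return false; // mixed-radix typically faster for large NTTs
  if (logn <= 11) return true;  // radix-2 typically faster for small NTTs (batch <= 256)
  const int log_batch = (int)log2(batch_size);
  return (logn + log_batch <= 18); // approximate crossover point
}
int main()
{
  const int cases[][2] = {{20, 1}, {10, 256}, {13, 16}, {13, 64}};
  for (const auto& c : cases) // e.g. logn=13, batch=64: 13 + 6 = 19 > 18 -> MixedRadix
    printf("logn=%d batch=%d -> %s\n", c[0], c[1], auto_picks_radix2(c[0], c[1]) ? "Radix2" : "MixedRadix");
  return 0;
}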
template <typename S, typename E>
cudaError_t radix2_ntt(
E* d_input,
E* d_output,
S* twiddles,
int ntt_size,
int max_size,
int batch_size,
bool is_inverse,
Ordering ordering,
S* arbitrary_coset,
int coset_gen_index,
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();
const int logn = int(log2(ntt_size));
bool dit = true;
bool reverse_input = false;
switch (ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
case Ordering::kNM:
dit = false;
break;
case Ordering::kRR:
reverse_input = true;
dit = false;
break;
case Ordering::kRN:
case Ordering::kMN:
dit = true;
reverse_input = false;
}
if (reverse_input) reverse_order_batch(d_input, ntt_size, logn, batch_size, cuda_stream, d_output);
CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, ntt_size, twiddles, max_size, batch_size, logn, is_inverse, dit,
arbitrary_coset, coset_gen_index, cuda_stream, d_output));
return CHK_LAST();
}
template <typename S, typename E>
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
{
@@ -501,37 +571,17 @@ namespace ntt {
h_coset.clear();
}
// (heuristic) cutoff point where mixed-radix is faster than radix-2
const bool is_small_ntt = (logn < 16) && ((size_t)size * batch_size < (1 << 20));
const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || !is_NN;
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
const bool is_inverse = dir == NTTDir::kInverse;
if (is_radix2_algorithm) {
bool ct_butterfly = true;
bool reverse_input = false;
switch (config.ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
ct_butterfly = false;
break;
case Ordering::kRR:
reverse_input = true;
ct_butterfly = false;
break;
}
if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
} else { // mixed-radix algorithm
CHK_IF_RETURN(ntt::radix2_ntt(
d_input, d_output, Domain<S>::twiddles, size, Domain<S>::max_size, batch_size, is_inverse, config.ordering,
coset, coset_index, stream));
} else {
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
Domain<S>::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, coset, coset_index, stream));
Domain<S>::max_log_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream));
}
if (!are_outputs_on_device)
@@ -550,14 +600,14 @@ namespace ntt {
{
device_context::DeviceContext ctx = device_context::get_default_device_context();
NTTConfig<S> config = {
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
false, // is_force_radix2
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
NttAlgorithm::Auto, // ntt_algorithm
};
return config;
}

View File

@@ -59,8 +59,28 @@ namespace ntt {
* a_4, a_2, a_6, a_1, a_5, a_3, a_7\} \f$).
* - kRN: inputs are bit-reversed-order and outputs are natural-order.
* - kRR: inputs and outputs are bit-reversed-order.
*
* Mixed-Radix NTT: digit-reversal is a generalization of bit-reversal, where the latter is the special case of
* 1-bit digits. Mixed-radix NTTs of different sizes generate different input/output reorderings; however, for a
* given size N, any two mixed-radix NTTs of size N are guaranteed to share the same digit-reversal pattern. The
* orderings kNM and kMN below are conceptually like kNR and kRN but use this mixed-digit reordering. For pipelines
* of the form '(1) NTT, (2) elementwise ops, (3) INTT', kNM and kMN are the most efficient orderings.
* Note: kNR, kRN, kRR refer to the radix-2 (bit-reversal) pattern. These cases are supported by the mixed-radix NTT
* with reduced efficiency compared to kNM and kMN.
* - kNM: inputs are natural-order and outputs are digit-reversed-order (=mixed).
* - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order.
*/
enum class Ordering { kNN, kNR, kRN, kRR };
enum class Ordering { kNN, kNR, kRN, kRR, kNM, kMN };
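A minimal usage sketch (not part of this diff) of the '(1) NTT, (2) elementwise ops, (3) INTT' pattern that the kNM/kMN note above refers to, mirroring the polynomial-multiplication example updated in this commit; the device buffers and the mul_elementwise helper are assumptions, and error handling is omitted:
// kNM/kMN let the mixed-radix NTT skip the extra input/output reordering passes (see Ordering docs above).
template <typename E>
void poly_mul_on_device(E* d_a, E* d_b, E* d_out, int ntt_size)
{
  auto config = ntt::DefaultNTTConfig<E>();
  config.are_inputs_on_device = true;
  config.are_outputs_on_device = true;

  config.ordering = ntt::Ordering::kNM; // natural-order in, digit-reversed ("mixed") out
  ntt::NTT(d_a, ntt_size, ntt::NTTDir::kForward, config, d_a);
  ntt::NTT(d_b, ntt_size, ntt::NTTDir::kForward, config, d_b);

  mul_elementwise(d_a, d_b, d_out, ntt_size); // hypothetical elementwise-multiply helper

  config.ordering = ntt::Ordering::kMN; // digit-reversed in, natural-order out
  ntt::NTT(d_out, ntt_size, ntt::NTTDir::kInverse, config, d_out);
}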
/**
* @enum NttAlgorithm
* Which NTT algorithm to use. Options are:
* - Auto: the implementation selects automatically based on a heuristic. This value is a good default for most cases.
* - Radix2: explicitly select radix-2 NTT algorithm
* - MixedRadix: explicitly select mixed-radix NTT algorithm
*/
enum class NttAlgorithm { Auto, Radix2, MixedRadix };
/**
* @struct NTTConfig
@@ -80,8 +100,8 @@ namespace ntt {
* non-blocking and you'd need to synchronize it explicitly by running
* `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT
* function will block the current CPU thread. */
bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects
radix-2 or mixed-radix algorithm based on heuristics). */
NttAlgorithm ntt_algorithm; /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
selects radix-2 or mixed-radix algorithm based on heuristics). */
};
/**

View File

@@ -35,22 +35,25 @@ int main(int argc, char** argv)
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
float icicle_time, new_time;
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 16; // assuming second input is the log-size
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming second input is the log-size
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : false;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 32;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 1;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
const ntt::Ordering ordering = ntt::Ordering::kNN;
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
: ordering == ntt::Ordering::kNR ? "NR"
: ordering == ntt::Ordering::kRN ? "RN"
: "RR";
: ordering == ntt::Ordering::kRR ? "RR"
: ordering == ntt::Ordering::kNM ? "NM"
: "MN";
printf(
"running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d\n", NTT_LOG_SIZE, ordering_str,
INPLACE, INV, BATCH_SIZE, COSET_IDX);
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s\n", NTT_LOG_SIZE, INPLACE, INV,
BATCH_SIZE, COSET_IDX, ordering_str);
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
@@ -103,7 +106,7 @@ int main(int argc, char** argv)
auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
// NEW
CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt)
ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
@@ -116,7 +119,7 @@ int main(int argc, char** argv)
// OLD
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = true;
ntt_config.ntt_algorithm = ntt::NttAlgorithm::Radix2;
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld);
}

View File

@@ -30,6 +30,17 @@ pub enum NTTDir {
/// a_4, a_2, a_6, a_1, a_5, a_3, a_7`.
/// - kRN: inputs are bit-reversed-order and outputs are natural-order.
/// - kRR: inputs and outputs are bit-reversed-order.
///
/// Mixed-Radix NTT: digit-reversal is a generalization of bit-reversal, where the latter is the special case of
/// 1-bit digits. Mixed-radix NTTs of different sizes generate different input/output reorderings; however, for a
/// given size N, any two mixed-radix NTTs of size N are guaranteed to share the same digit-reversal pattern. The
/// orderings kNM and kMN below are conceptually like kNR and kRN but use this mixed-digit reordering. For pipelines
/// of the form '(1) NTT, (2) elementwise ops, (3) INTT', kNM and kMN are the most efficient orderings.
/// Note: kNR, kRN, kRR refer to the radix-2 (bit-reversal) pattern. These cases are supported by the mixed-radix NTT
/// with reduced efficiency compared to kNM and kMN.
/// - kNM: inputs are natural-order and outputs are digit-reversed-order (=mixed).
/// - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -38,6 +49,22 @@ pub enum Ordering {
kNR,
kRN,
kRR,
kNM,
kMN,
}
/// Which NTT algorithm to use. Options are:
/// - Auto: the implementation selects automatically based on a heuristic. This value is a good default for most cases.
/// - Radix2: explicitly select radix-2 NTT algorithm
/// - MixedRadix: explicitly select mixed-radix NTT algorithm
///
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NttAlgorithm {
Auto,
Radix2,
MixedRadix,
}
/// Struct that encodes NTT parameters to be passed into the [ntt](ntt) function.
@@ -57,8 +84,9 @@ pub struct NTTConfig<'a, S> {
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
pub is_async: bool,
/// Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects radix-2 or mixed-radix algorithm based on heuristics).
pub is_force_radix2: bool,
/// Explicitly select the NTT algorithm. Default value: Auto (the implementation selects radix-2 or mixed-radix algorithm based
/// on heuristics).
pub ntt_algorithm: NttAlgorithm,
}
impl<'a, S: FieldImpl> NTTConfig<'a, S> {
@@ -72,7 +100,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
are_inputs_on_device: false,
are_outputs_on_device: false,
is_async: false,
is_force_radix2: false,
ntt_algorithm: NttAlgorithm::Auto,
}
}
}

View File

@@ -6,7 +6,7 @@ use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;
use crate::{
ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, Ordering},
ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
traits::{ArkConvertible, FieldImpl, GenerateRandom},
};
@@ -61,23 +61,26 @@ where
let scalars_mont = unsafe { &*(&ark_scalars[..] as *const _ as *const [F]) };
let scalars_mont_h = HostOrDeviceSlice::on_host(scalars_mont.to_vec());
let config = get_default_ntt_config();
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars_mont_h, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(ntt_result.as_slice(), scalars_mont);
let mut config = get_default_ntt_config();
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars_mont_h, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(ntt_result.as_slice(), scalars_mont);
let mut ark_ntt_result = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_ntt_result);
assert_ne!(ark_ntt_result, ark_scalars);
let mut ark_ntt_result = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_ntt_result);
assert_ne!(ark_ntt_result, ark_scalars);
let ntt_result_as_ark =
unsafe { &*(ntt_result.as_slice() as *const _ as *const [<F as ArkConvertible>::ArkEquivalent]) };
assert_eq!(ark_ntt_result, ntt_result_as_ark);
let ntt_result_as_ark =
unsafe { &*(ntt_result.as_slice() as *const _ as *const [<F as ArkConvertible>::ArkEquivalent]) };
assert_eq!(ark_ntt_result, ntt_result_as_ark);
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&ntt_result, NTTDir::kInverse, &config, &mut intt_result).unwrap();
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&ntt_result, NTTDir::kInverse, &config, &mut intt_result).unwrap();
assert_eq!(intt_result.as_slice(), scalars_mont);
assert_eq!(intt_result.as_slice(), scalars_mont);
}
}
}
@@ -103,55 +106,58 @@ where
.map(|v| v.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
let mut config = get_default_ntt_config();
config.ordering = Ordering::kNR;
let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
let mut ntt_result_2 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_1).unwrap();
assert_ne!(*ntt_result_1.as_slice(), scalars);
config.coset_gen = F::from_ark(test_size_rou);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_2).unwrap();
let mut ntt_large_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
// back to non-coset NTT
config.coset_gen = F::one();
scalars.resize(test_size, F::zero());
ntt(
&HostOrDeviceSlice::on_host(scalars.clone()),
NTTDir::kForward,
&config,
&mut ntt_large_result,
)
.unwrap();
assert_eq!(*ntt_result_1.as_slice(), ntt_large_result.as_slice()[..small_size]);
assert_eq!(*ntt_result_2.as_slice(), ntt_large_result.as_slice()[small_size..]);
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
let mut config = get_default_ntt_config();
config.ordering = Ordering::kNR;
config.ntt_algorithm = alg;
let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
let mut ntt_result_2 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_1).unwrap();
assert_ne!(*ntt_result_1.as_slice(), scalars);
config.coset_gen = F::from_ark(test_size_rou);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_2).unwrap();
let mut ntt_large_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
// back to non-coset NTT
config.coset_gen = F::one();
scalars.resize(test_size, F::zero());
ntt(
&HostOrDeviceSlice::on_host(scalars.clone()),
NTTDir::kForward,
&config,
&mut ntt_large_result,
)
.unwrap();
assert_eq!(*ntt_result_1.as_slice(), ntt_large_result.as_slice()[..small_size]);
assert_eq!(*ntt_result_2.as_slice(), ntt_large_result.as_slice()[small_size..]);
let mut ark_large_scalars = ark_scalars.clone();
ark_small_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_large_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(
ark_scalars[..small_size],
list_to_reverse_bit_order(&ntt_result_as_ark[small_size..])
);
ark_large_domain.fft_in_place(&mut ark_large_scalars);
assert_eq!(ark_large_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
let mut ark_large_scalars = ark_scalars.clone();
ark_small_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_large_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(
ark_scalars[..small_size],
list_to_reverse_bit_order(&ntt_result_as_ark[small_size..])
);
ark_large_domain.fft_in_place(&mut ark_large_scalars);
assert_eq!(ark_large_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
config.coset_gen = F::from_ark(test_size_rou);
config.ordering = Ordering::kRN;
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&ntt_result_2, NTTDir::kInverse, &config, &mut intt_result).unwrap();
assert_eq!(*intt_result.as_slice(), scalars[..small_size]);
config.coset_gen = F::from_ark(test_size_rou);
config.ordering = Ordering::kRN;
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&ntt_result_2, NTTDir::kInverse, &config, &mut intt_result).unwrap();
assert_eq!(*intt_result.as_slice(), scalars[..small_size]);
ark_small_domain.ifft_in_place(&mut ark_scalars);
let intt_result_as_ark = intt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars[..small_size], intt_result_as_ark);
ark_small_domain.ifft_in_place(&mut ark_scalars);
let intt_result_as_ark = intt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars[..small_size], intt_result_as_ark);
}
}
}
@@ -183,31 +189,34 @@ where
.collect::<Vec<F::ArkEquivalent>>();
let mut config = get_default_ntt_config();
config.ordering = Ordering::kNR;
config.coset_gen = F::from_ark(coset_gen);
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(scalars.as_slice(), ntt_result.as_slice());
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ordering = Ordering::kNR;
config.ntt_algorithm = alg;
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(scalars.as_slice(), ntt_result.as_slice());
let ark_scalars_copy = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
ark_domain.ifft_in_place(&mut ark_scalars);
assert_eq!(ark_scalars, ark_scalars_copy);
let ark_scalars_copy = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
ark_domain.ifft_in_place(&mut ark_scalars);
assert_eq!(ark_scalars, ark_scalars_copy);
config.ordering = Ordering::kRN;
ntt(&ntt_result, NTTDir::kInverse, &config, &mut scalars).unwrap();
let ntt_result_as_ark = scalars
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, ntt_result_as_ark);
config.ordering = Ordering::kRN;
ntt(&ntt_result, NTTDir::kInverse, &config, &mut scalars).unwrap();
let ntt_result_as_ark = scalars
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, ntt_result_as_ark);
}
}
}
}
@@ -216,7 +225,7 @@ pub fn check_ntt_batch<F: FieldImpl>()
where
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
{
let test_sizes = [1 << 4, 1 << 14];
let test_sizes = [1 << 4, 1 << 12];
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
@@ -226,26 +235,37 @@ where
for coset_gen in coset_generators {
for is_inverse in [NTTDir::kInverse, NTTDir::kForward] {
for ordering in [Ordering::kNN, Ordering::kNR, Ordering::kRN, Ordering::kRR] {
for ordering in [
Ordering::kNN,
Ordering::kNR,
Ordering::kRN,
Ordering::kRR,
Ordering::kNM,
Ordering::kMN,
] {
config.coset_gen = coset_gen;
config.ordering = ordering;
config.batch_size = batch_size as i32;
let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
config.batch_size = 1;
let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
for i in 0..batch_size {
ntt(
&HostOrDeviceSlice::on_host(scalars[i * test_size..(i + 1) * test_size].to_vec()),
is_inverse,
&config,
&mut one_ntt_result,
)
.unwrap();
assert_eq!(
batch_ntt_result[i * test_size..(i + 1) * test_size],
*one_ntt_result.as_slice()
);
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.batch_size = batch_size as i32;
config.ntt_algorithm = alg;
let mut batch_ntt_result =
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
config.batch_size = 1;
let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
for i in 0..batch_size {
ntt(
&HostOrDeviceSlice::on_host(scalars[i * test_size..(i + 1) * test_size].to_vec()),
is_inverse,
&config,
&mut one_ntt_result,
)
.unwrap();
assert_eq!(
batch_ntt_result[i * test_size..(i + 1) * test_size],
*one_ntt_result.as_slice()
);
}
}
}
}
@@ -286,22 +306,25 @@ where
config
.ctx
.stream = &stream;
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
scalars_d
.copy_to_host_async(&mut intt_result_h, &stream)
.unwrap();
stream
.synchronize()
.unwrap();
assert_eq!(scalars_h, intt_result_h);
if coset_gen == F::one() {
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
ntt_out_d
.copy_to_host(&mut ntt_result_h)
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
scalars_d
.copy_to_host_async(&mut intt_result_h, &stream)
.unwrap();
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
stream
.synchronize()
.unwrap();
assert_eq!(scalars_h, intt_result_h);
if coset_gen == F::one() {
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
ntt_out_d
.copy_to_host(&mut ntt_result_h)
.unwrap();
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
}
}
}
}