Mirror of https://github.com/pseXperiments/icicle.git (synced 2026-01-10 05:28:01 -05:00)
Mixed-radix NTT support all orderings (#371)
- Mixed-radix NTT orderings support
- Radix-2 small refactor: split core logic into a function and renamed ct_butterfly to dit
- Test both radix-2 and mixed-radix algorithms in all NTT tests
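In user code, algorithm selection moves from the old is_force_radix2 flag to the new NttAlgorithm enum. A condensed sketch based on the example changes below (use_radix2 is a hypothetical stand-in for the example's argv parsing):

    auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
    ntt_config.ordering = ntt::Ordering::kNN;
    // Pick the algorithm explicitly, or leave NttAlgorithm::Auto (the default)
    // to let the heuristic choose between radix-2 and mixed-radix.
    ntt_config.ntt_algorithm = use_radix2 ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix;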
@@ -46,10 +46,10 @@ int main(int argc, char** argv)
// init domain
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward
ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false;
const bool is_radix2_alg = (argc > 1) ? atoi(argv[1]) : false;
ntt_config.ntt_algorithm = is_radix2_alg ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix;

const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix";
const char* ntt_alg_str = is_radix2_alg ? "Radix-2" : "Mixed-Radix";
std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: ";

CHK_IF_RETURN(cudaEventCreate(&start));
@@ -78,6 +78,7 @@ int main(int argc, char** argv)
// (3) NTT for A,B from cpu to gpu
ntt_config.are_inputs_on_device = false;
ntt_config.are_outputs_on_device = true;
ntt_config.ordering = ntt::Ordering::kNM;
CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));
@@ -89,6 +90,7 @@ int main(int argc, char** argv)
// (5) INTT (in place)
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
ntt_config.ordering = ntt::Ordering::kMN;
CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));

CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream));
@@ -6,7 +6,7 @@
namespace ntt {

static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
{
uint32_t rev_num = 0, temp, dig_len;
if (dit) {
@@ -29,10 +29,34 @@ namespace ntt {
return rev_num;
}

static inline __device__ uint32_t bit_rev(uint32_t num, uint32_t log_size) { return __brev(num) >> (32 - log_size); }

enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural };

static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, eRevType rev_type)
{
switch (rev_type) {
case eRevType::RevToMixedRev:
// R -> N -> MR
return dig_rev(bit_rev(num, log_size), log_size, dit);
case eRevType::MixedRevToRev:
// MR -> N -> R
return bit_rev(dig_rev(num, log_size, dit), log_size);
case eRevType::NaturalToMixedRev:
case eRevType::MixedRevToNatural:
return dig_rev(num, log_size, dit);
case eRevType::NaturalToRev:
return bit_rev(num, log_size);
default:
return num;
}
return num;
}
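As a concrete illustration of the reversal primitives above (a host-side sketch only, not part of the diff; it assumes a uniform radix-4 decomposition, i.e. 2-bit digits, whereas the device dig_rev handles the actual mixed-radix digit widths):

    // bit reversal of the low log_size bits (1-bit digits)
    uint32_t bit_rev_host(uint32_t num, uint32_t log_size)
    {
      uint32_t rev = 0;
      for (uint32_t i = 0; i < log_size; ++i)
        rev |= ((num >> i) & 1) << (log_size - 1 - i);
      return rev;
    }

    // digit reversal with 2-bit digits (radix-4); assumes log_size is even
    uint32_t dig_rev_radix4_host(uint32_t num, uint32_t log_size)
    {
      uint32_t rev = 0;
      for (uint32_t i = 0; i < log_size; i += 2)
        rev |= ((num >> i) & 3) << (log_size - 2 - i);
      return rev;
    }

    // For log_size = 4: bit_rev_host(6, 4) == 6 (0b0110 -> 0b0110), while
    // dig_rev_radix4_host(6, 4) == 9 (digits (01)(10) -> (10)(01) = 0b1001),
    // so the "mixed" (digit-reversed) order differs from the bit-reversed order.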
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void
reorder_digits_inplace_and_normalize_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
static __global__ void reorder_digits_inplace_and_normalize_kernel(
E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -50,7 +74,7 @@ namespace ntt {
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = dig_rev(next_element, log_size, dit);
next_element = generalized_rev(next_element, log_size, dit, rev_type);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element + size * batch_idx;
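The in-place kernel above permutes the array by following permutation cycles: only the thread that starts at the smallest index of a cycle collects the whole group and rotates it; every other thread returns early. A host-side sketch of the same idea (illustration only, not part of the diff):

    #include <cstdint>
    #include <vector>

    // Apply the permutation i -> perm(i) in place by walking each cycle; only the
    // smallest index of a cycle performs the rotation, mirroring the kernel's
    // "next_element < idx" / "next_element == idx" checks above.
    template <typename E, typename Perm>
    void permute_inplace_by_cycles(std::vector<E>& arr, Perm perm)
    {
      const uint32_t n = static_cast<uint32_t>(arr.size());
      for (uint32_t idx = 0; idx < n; ++idx) {
        std::vector<uint32_t> group = {idx};
        uint32_t next = perm(idx);
        bool owns_cycle = true;
        while (next != idx) {
          if (next < idx) { owns_cycle = false; break; } // another start owns this cycle
          group.push_back(next);
          next = perm(next);
        }
        if (!owns_cycle) continue;
        // value at group[i] moves to group[i+1]; the last value wraps around to group[0]
        E wrapped = arr[group.back()];
        for (size_t i = group.size() - 1; i > 0; --i) arr[group[i]] = arr[group[i - 1]];
        arr[group[0]] = wrapped;
      }
    }

    // usage with a host equivalent of generalized_rev as the permutation, e.g.:
    // permute_inplace_by_cycles(vec, [&](uint32_t i) { return my_rev(i, logn); });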
@@ -67,16 +91,17 @@ namespace ntt {
template <typename E, typename S>
__launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
E* arr, E* arr_reordered, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t rd = tid;
uint32_t wr = ((tid >> log_size) << log_size) + dig_rev(tid & ((1 << log_size) - 1), log_size, dit);
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}

template <typename E, typename S>
static __global__ void BatchMulKernelDigReverse(
static __global__ void batch_elementwise_mul_with_reorder(
E* in_vec,
int n_elements,
int batch_size,
@@ -84,14 +109,14 @@ namespace ntt {
int step,
int n_scalars,
int logn,
bool digit_rev,
eRevType rev_type,
bool dit,
E* out_vec)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= n_elements * batch_size) return;
int64_t scalar_id = tid % n_elements;
if (digit_rev) scalar_id = dig_rev(tid, logn, dit);
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, rev_type);
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
@@ -630,56 +655,80 @@ namespace ntt {
{
CHK_INIT_IF_RETURN();

// TODO: can we support all orderings? Note that reversal is generally digit reverse (generalization of bit reverse)
if (ordering != Ordering::kNN) {
throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only");
}

const int logn = int(log2(ntt_size));

const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
const int NOF_THREADS = min(64, (1 << logn) * batch_size);

// Note: dif is slightly faster than dit but since reordering is a post-process stage, it must be computed in-place
// which make it slower e2e in most cases. dit reorders as a pre-process stage and therefore supports both in-place
// and out-of-place (when in!=out);
const bool reverse_input = ordering == Ordering::kNN;
const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
bool is_normalize = is_inverse;
const bool is_on_coset = (coset_gen_index != 0) || arbitrary_coset;
const int n_twiddles = 1 << max_logn;
// Note: for evaluation on coset, need to reorder the coset too to match the data for element-wise multiplication
eRevType reverse_input = None, reverse_output = None, reverse_coset = None;
bool dit = false;
switch (ordering) {
case Ordering::kNN:
reverse_input = eRevType::NaturalToMixedRev;
dit = true;
break;
case Ordering::kRN:
reverse_input = eRevType::RevToMixedRev;
dit = true;
reverse_coset = is_inverse ? eRevType::None : eRevType::NaturalToRev;
break;
case Ordering::kNR:
reverse_output = eRevType::MixedRevToRev;
reverse_coset = is_inverse ? eRevType::NaturalToRev : eRevType::None;
break;
case Ordering::kRR:
reverse_input = eRevType::RevToMixedRev;
dit = true;
reverse_output = eRevType::NaturalToRev;
reverse_coset = eRevType::NaturalToRev;
break;
case Ordering::kMN:
dit = true;
reverse_coset = is_inverse ? None : eRevType::NaturalToMixedRev;
break;
case Ordering::kNM:
reverse_coset = is_inverse ? eRevType::NaturalToMixedRev : eRevType::None;
break;
}

// TODO: fuse BatchMulKernelDigReverse with input reorder (and normalize)?
if (is_on_coset && !is_inverse) {
BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);

d_input = d_output;
}

if (reverse_input) {
// Note: fused reorder and normalize (for INTT)
if (reverse_input != eRevType::None) {
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
d_input, d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
d_input = d_output;
}

// inplace ntt
CHK_IF_RETURN(large_ntt(
d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
is_normalize, is_dit, cuda_stream));
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
(is_normalize && reverse_output == eRevType::None), dit, cuda_stream));

if (reverse_output != eRevType::None) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, reverse_output, is_normalize, S::inv_log_size(logn));
}

if (is_on_coset && is_inverse) {
BatchMulKernelDigReverse<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, false /*digit_rev*/, is_dit, d_output);
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
}

return CHK_LAST();
@@ -278,7 +278,7 @@ namespace ntt {
int batch_size,
int logn,
bool inverse,
bool ct_buttterfly,
bool dit,
S* arbitrary_coset,
int coset_gen_index,
cudaStream_t stream,
@@ -309,9 +309,9 @@ namespace ntt {
if (direct_coset)
utils_internal::BatchMulKernel<E, S><<<num_blocks_coset, num_threads_coset, 0, stream>>>(
d_input, n, batch_size, arbitrary_coset ? arbitrary_coset : d_twiddles, arbitrary_coset ? 1 : coset_gen_index,
n_twiddles, logn, ct_buttterfly, d_output);
n_twiddles, logn, dit, d_output);

if (ct_buttterfly) {
if (dit) {
if (is_shared_mem_enabled)
ntt_template_kernel_shared<<<num_blocks, num_threads, shared_mem, stream>>>(
direct_coset ? d_output : d_input, 1 << logn_shmem, d_twiddles, n_twiddles, total_tasks, 0, logn_shmem,
@@ -340,7 +340,7 @@ namespace ntt {
if (is_on_coset)
utils_internal::BatchMulKernel<E, S><<<num_blocks_coset, num_threads_coset, 0, stream>>>(
d_output, n, batch_size, arbitrary_coset ? arbitrary_coset : d_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, -n_twiddles, logn, !ct_buttterfly, d_output);
arbitrary_coset ? 1 : -coset_gen_index, -n_twiddles, logn, !dit, d_output);

utils_internal::NormalizeKernel<E, S>
<<<num_blocks_coset, num_threads_coset, 0, stream>>>(d_output, S::inv_log_size(logn), n * batch_size);
@@ -452,6 +452,76 @@ namespace ntt {
return CHK_LAST();
}

template <typename S>
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
{
const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2;
const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg;
if (is_force_radix2) return true;

const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix;
if (is_user_selected_mixed_radix_alg) return false;

// Heuristic to automatically select an algorithm
// Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and
// the specific GPU.
// the following heuristic is a simplification based on measurements. Users can try both and select the algorithm
// based on the specific case via the 'NTTConfig.ntt_algorithm' field

if (logn >= 16) return false; // mixed-radix is typically faster in those cases
if (logn <= 11) return true; // radix-2 is typically faster for batch<=256 in those cases
const int log_batch = (int)log2(batch_size);
return (logn + log_batch <= 18); // almost the cutoff point where both are equal
}
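Reading the heuristic above with concrete numbers: a 2^12-point NTT with batch_size = 512 gives logn + log_batch = 12 + 9 = 21 > 18, so the function returns false and the mixed-radix algorithm is used, while the same size with batch_size = 16 gives 12 + 4 = 16 <= 18 and falls back to radix-2.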
template <typename S, typename E>
cudaError_t radix2_ntt(
E* d_input,
E* d_output,
S* twiddles,
int ntt_size,
int max_size,
int batch_size,
bool is_inverse,
Ordering ordering,
S* arbitrary_coset,
int coset_gen_index,
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();

const int logn = int(log2(ntt_size));

bool dit = true;
bool reverse_input = false;
switch (ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
case Ordering::kNM:
dit = false;
break;
case Ordering::kRR:
reverse_input = true;
dit = false;
break;
case Ordering::kRN:
case Ordering::kMN:
dit = true;
reverse_input = false;
}

if (reverse_input) reverse_order_batch(d_input, ntt_size, logn, batch_size, cuda_stream, d_output);

CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, ntt_size, twiddles, max_size, batch_size, logn, is_inverse, dit,
arbitrary_coset, coset_gen_index, cuda_stream, d_output));

return CHK_LAST();
}

template <typename S, typename E>
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
{
@@ -501,37 +571,17 @@ namespace ntt {
h_coset.clear();
}

// (heuristic) cutoff point where mixed-radix is faster than radix-2
const bool is_small_ntt = (logn < 16) && ((size_t)size * batch_size < (1 << 20));
const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
const bool is_radix2_algorithm = config.is_force_radix2 || is_small_ntt || !is_NN;
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
const bool is_inverse = dir == NTTDir::kInverse;

if (is_radix2_algorithm) {
bool ct_butterfly = true;
bool reverse_input = false;
switch (config.ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
ct_butterfly = false;
break;
case Ordering::kRR:
reverse_input = true;
ct_butterfly = false;
break;
}

if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);

CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));

} else { // mixed-radix algorithm
CHK_IF_RETURN(ntt::radix2_ntt(
d_input, d_output, Domain<S>::twiddles, size, Domain<S>::max_size, batch_size, is_inverse, config.ordering,
coset, coset_index, stream));
} else {
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
Domain<S>::max_log_size, batch_size, dir == NTTDir::kInverse, config.ordering, coset, coset_index, stream));
Domain<S>::max_log_size, batch_size, is_inverse, config.ordering, coset, coset_index, stream));
}

if (!are_outputs_on_device)
@@ -550,14 +600,14 @@ namespace ntt {
{
device_context::DeviceContext ctx = device_context::get_default_device_context();
NTTConfig<S> config = {
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
false, // is_force_radix2
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
NttAlgorithm::Auto, // ntt_algorithm
};
return config;
}
@@ -59,8 +59,28 @@ namespace ntt {
* a_4, a_2, a_6, a_1, a_5, a_3, a_7\} \f$).
* - kRN: inputs are bit-reversed-order and outputs are natural-order.
* - kRR: inputs and outputs are bit-reversed-order.
*
* Mixed-Radix NTT: digit-reversal is a generalization of bit-reversal where the latter is a special case with 1b
* digits. Mixed-radix NTTs of different sizes would generate different reordering of inputs/outputs. Having said
* that, for a given size N it is guaranteed that every two mixed-radix NTTs of size N would have the same
* digit-reversal pattern. The following orderings kNM and kMN are conceptually like kNR and kRN but for
* mixed-digit-reordering. Note that for the cases '(1) NTT, (2) elementwise ops and (3) INTT' kNM and kMN are most
* efficient.
* Note: kNR, kRN, kRR refer to the radix-2 NTT reversal pattern. Those cases are supported by mixed-radix NTT with
* reduced efficiency compared to kNM and kMN.
* - kNM: inputs are natural-order and outputs are digit-reversed-order (=mixed).
* - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order.
*/
enum class Ordering { kNN, kNR, kRN, kRR };
enum class Ordering { kNN, kNR, kRN, kRR, kNM, kMN };
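The polynomial-multiplication example updated in this commit follows exactly the pattern described above; a condensed sketch of its flow (identifiers as in the example diff, error checking omitted):

    // forward NTTs go straight into digit-reversed ("mixed") order
    ntt_config.ordering = ntt::Ordering::kNM;
    ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA);
    ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB);

    // ... element-wise multiply GpuA * GpuB into MulGpu; ordering does not matter for pointwise ops ...

    // the inverse NTT consumes the mixed order and returns natural order
    ntt_config.ordering = ntt::Ordering::kMN;
    ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu);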
/**
* @enum NttAlgorithm
* Which NTT algorithm to use. options are:
* - Auto: implementation selects automatically based on heuristic. This value is a good default for most cases.
* - Radix2: explicitly select radix-2 NTT algorithm
* - MixedRadix: explicitly select mixed-radix NTT algorithm
*/
enum class NttAlgorithm { Auto, Radix2, MixedRadix };

/**
* @struct NTTConfig
@@ -80,8 +100,8 @@ namespace ntt {
* non-blocking and you'd need to synchronize it explicitly by running
* `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT
* function will block the current CPU thread. */
bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects
radix-2 or mixed-radix algorithm based on heuristics). */
NttAlgorithm ntt_algorithm; /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
selects radix-2 or mixed-radix algorithm based on heuristics). */
};

/**
@@ -35,22 +35,25 @@ int main(int argc, char** argv)
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
float icicle_time, new_time;

int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 16; // assuming second input is the log-size
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming second input is the log-size
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : false;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 32;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 1;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;

const ntt::Ordering ordering = ntt::Ordering::kNN;
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
: ordering == ntt::Ordering::kNR ? "NR"
: ordering == ntt::Ordering::kRN ? "RN"
: "RR";
: ordering == ntt::Ordering::kRR ? "RR"
: ordering == ntt::Ordering::kNM ? "NM"
: "MN";

printf(
"running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d\n", NTT_LOG_SIZE, ordering_str,
INPLACE, INV, BATCH_SIZE, COSET_IDX);
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s\n", NTT_LOG_SIZE, INPLACE, INV,
BATCH_SIZE, COSET_IDX, ordering_str);

CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
@@ -103,7 +106,7 @@ int main(int argc, char** argv)
auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
// NEW
CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt)
ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
@@ -116,7 +119,7 @@ int main(int argc, char** argv)

// OLD
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = true;
ntt_config.ntt_algorithm = ntt::NttAlgorithm::Radix2;
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld);
}
@@ -30,6 +30,17 @@ pub enum NTTDir {
/// a_4, a_2, a_6, a_1, a_5, a_3, a_7`.
/// - kRN: inputs are bit-reversed-order and outputs are natural-order.
/// - kRR: inputs and outputs are bit-reversed-order.
///
/// Mixed-Radix NTT: digit-reversal is a generalization of bit-reversal where the latter is a special case with 1b
/// digits. Mixed-radix NTTs of different sizes would generate different reordering of inputs/outputs. Having said
/// that, for a given size N it is guaranteed that every two mixed-radix NTTs of size N would have the same
/// digit-reversal pattern. The following orderings kNM and kMN are conceptually like kNR and kRN but for
/// mixed-digit-reordering. Note that for the cases '(1) NTT, (2) elementwise ops and (3) INTT' kNM and kMN are most
/// efficient.
/// Note: kNR, kRN, kRR refer to the radix-2 NTT reversal pattern. Those cases are supported by mixed-radix NTT with
/// reduced efficiency compared to kNM and kMN.
/// - kNM: inputs are natural-order and outputs are digit-reversed-order (=mixed).
/// - kMN: inputs are digit-reversed-order (=mixed) and outputs are natural-order.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -38,6 +49,22 @@ pub enum Ordering {
kNR,
kRN,
kRR,
kNM,
kMN,
}

///Which NTT algorithm to use. options are:
///- Auto: implementation selects automatically based on heuristic. This value is a good default for most cases.
///- Radix2: explicitly select radix-2 NTT algorithm
///- MixedRadix: explicitly select mixed-radix NTT algorithm
///
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NttAlgorithm {
Auto,
Radix2,
MixedRadix,
}

/// Struct that encodes NTT parameters to be passed into the [ntt](ntt) function.
@@ -57,8 +84,9 @@ pub struct NTTConfig<'a, S> {
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
pub is_async: bool,
/// Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects radix-2 or mixed-radix algorithm based on heuristics).
pub is_force_radix2: bool,
/// Explicitly select the NTT algorithm. Default value: Auto (the implementation selects radix-2 or mixed-radix algorithm based
/// on heuristics
pub ntt_algorithm: NttAlgorithm,
}

impl<'a, S: FieldImpl> NTTConfig<'a, S> {
@@ -72,7 +100,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
are_inputs_on_device: false,
are_outputs_on_device: false,
is_async: false,
is_force_radix2: false,
ntt_algorithm: NttAlgorithm::Auto,
}
}
}
@@ -6,7 +6,7 @@ use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use icicle_cuda_runtime::stream::CudaStream;

use crate::{
ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, Ordering},
ntt::{get_default_ntt_config, initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
traits::{ArkConvertible, FieldImpl, GenerateRandom},
};
@@ -61,23 +61,26 @@ where
let scalars_mont = unsafe { &*(&ark_scalars[..] as *const _ as *const [F]) };
let scalars_mont_h = HostOrDeviceSlice::on_host(scalars_mont.to_vec());

let config = get_default_ntt_config();
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars_mont_h, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(ntt_result.as_slice(), scalars_mont);
let mut config = get_default_ntt_config();
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars_mont_h, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(ntt_result.as_slice(), scalars_mont);

let mut ark_ntt_result = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_ntt_result);
assert_ne!(ark_ntt_result, ark_scalars);
let mut ark_ntt_result = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_ntt_result);
assert_ne!(ark_ntt_result, ark_scalars);

let ntt_result_as_ark =
unsafe { &*(ntt_result.as_slice() as *const _ as *const [<F as ArkConvertible>::ArkEquivalent]) };
assert_eq!(ark_ntt_result, ntt_result_as_ark);
let ntt_result_as_ark =
unsafe { &*(ntt_result.as_slice() as *const _ as *const [<F as ArkConvertible>::ArkEquivalent]) };
assert_eq!(ark_ntt_result, ntt_result_as_ark);

let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&ntt_result, NTTDir::kInverse, &config, &mut intt_result).unwrap();
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&ntt_result, NTTDir::kInverse, &config, &mut intt_result).unwrap();

assert_eq!(intt_result.as_slice(), scalars_mont);
assert_eq!(intt_result.as_slice(), scalars_mont);
}
}
}
@@ -103,55 +106,58 @@ where
.map(|v| v.to_ark())
.collect::<Vec<F::ArkEquivalent>>();

let mut config = get_default_ntt_config();
config.ordering = Ordering::kNR;
let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
let mut ntt_result_2 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_1).unwrap();
assert_ne!(*ntt_result_1.as_slice(), scalars);
config.coset_gen = F::from_ark(test_size_rou);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_2).unwrap();
let mut ntt_large_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
// back to non-coset NTT
config.coset_gen = F::one();
scalars.resize(test_size, F::zero());
ntt(
&HostOrDeviceSlice::on_host(scalars.clone()),
NTTDir::kForward,
&config,
&mut ntt_large_result,
)
.unwrap();
assert_eq!(*ntt_result_1.as_slice(), ntt_large_result.as_slice()[..small_size]);
assert_eq!(*ntt_result_2.as_slice(), ntt_large_result.as_slice()[small_size..]);
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
let mut config = get_default_ntt_config();
config.ordering = Ordering::kNR;
config.ntt_algorithm = alg;
let mut ntt_result_1 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
let mut ntt_result_2 = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_1).unwrap();
assert_ne!(*ntt_result_1.as_slice(), scalars);
config.coset_gen = F::from_ark(test_size_rou);
ntt(&scalars_h, NTTDir::kForward, &config, &mut ntt_result_2).unwrap();
let mut ntt_large_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
// back to non-coset NTT
config.coset_gen = F::one();
scalars.resize(test_size, F::zero());
ntt(
&HostOrDeviceSlice::on_host(scalars.clone()),
NTTDir::kForward,
&config,
&mut ntt_large_result,
)
.unwrap();
assert_eq!(*ntt_result_1.as_slice(), ntt_large_result.as_slice()[..small_size]);
assert_eq!(*ntt_result_2.as_slice(), ntt_large_result.as_slice()[small_size..]);

let mut ark_large_scalars = ark_scalars.clone();
ark_small_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_large_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(
ark_scalars[..small_size],
list_to_reverse_bit_order(&ntt_result_as_ark[small_size..])
);
ark_large_domain.fft_in_place(&mut ark_large_scalars);
assert_eq!(ark_large_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
let mut ark_large_scalars = ark_scalars.clone();
ark_small_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_large_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(
ark_scalars[..small_size],
list_to_reverse_bit_order(&ntt_result_as_ark[small_size..])
);
ark_large_domain.fft_in_place(&mut ark_large_scalars);
assert_eq!(ark_large_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));

config.coset_gen = F::from_ark(test_size_rou);
config.ordering = Ordering::kRN;
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&ntt_result_2, NTTDir::kInverse, &config, &mut intt_result).unwrap();
assert_eq!(*intt_result.as_slice(), scalars[..small_size]);
config.coset_gen = F::from_ark(test_size_rou);
config.ordering = Ordering::kRN;
let mut intt_result = HostOrDeviceSlice::on_host(vec![F::zero(); small_size]);
ntt(&ntt_result_2, NTTDir::kInverse, &config, &mut intt_result).unwrap();
assert_eq!(*intt_result.as_slice(), scalars[..small_size]);

ark_small_domain.ifft_in_place(&mut ark_scalars);
let intt_result_as_ark = intt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars[..small_size], intt_result_as_ark);
ark_small_domain.ifft_in_place(&mut ark_scalars);
let intt_result_as_ark = intt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars[..small_size], intt_result_as_ark);
}
}
}
@@ -183,31 +189,34 @@ where
.collect::<Vec<F::ArkEquivalent>>();

let mut config = get_default_ntt_config();
config.ordering = Ordering::kNR;
config.coset_gen = F::from_ark(coset_gen);
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(scalars.as_slice(), ntt_result.as_slice());
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ordering = Ordering::kNR;
config.ntt_algorithm = alg;
let mut ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); test_size]);
ntt(&scalars, NTTDir::kForward, &config, &mut ntt_result).unwrap();
assert_ne!(scalars.as_slice(), ntt_result.as_slice());

let ark_scalars_copy = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
ark_domain.ifft_in_place(&mut ark_scalars);
assert_eq!(ark_scalars, ark_scalars_copy);
let ark_scalars_copy = ark_scalars.clone();
ark_domain.fft_in_place(&mut ark_scalars);
let ntt_result_as_ark = ntt_result
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, list_to_reverse_bit_order(&ntt_result_as_ark));
ark_domain.ifft_in_place(&mut ark_scalars);
assert_eq!(ark_scalars, ark_scalars_copy);

config.ordering = Ordering::kRN;
ntt(&ntt_result, NTTDir::kInverse, &config, &mut scalars).unwrap();
let ntt_result_as_ark = scalars
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, ntt_result_as_ark);
config.ordering = Ordering::kRN;
ntt(&ntt_result, NTTDir::kInverse, &config, &mut scalars).unwrap();
let ntt_result_as_ark = scalars
.as_slice()
.iter()
.map(|p| p.to_ark())
.collect::<Vec<F::ArkEquivalent>>();
assert_eq!(ark_scalars, ntt_result_as_ark);
}
}
}
}
@@ -216,7 +225,7 @@ pub fn check_ntt_batch<F: FieldImpl>()
where
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
{
let test_sizes = [1 << 4, 1 << 14];
let test_sizes = [1 << 4, 1 << 12];
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
@@ -226,26 +235,37 @@ where

for coset_gen in coset_generators {
for is_inverse in [NTTDir::kInverse, NTTDir::kForward] {
for ordering in [Ordering::kNN, Ordering::kNR, Ordering::kRN, Ordering::kRR] {
for ordering in [
Ordering::kNN,
Ordering::kNR,
Ordering::kRN,
Ordering::kRR,
Ordering::kNM,
Ordering::kMN,
] {
config.coset_gen = coset_gen;
config.ordering = ordering;
config.batch_size = batch_size as i32;
let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
config.batch_size = 1;
let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
for i in 0..batch_size {
ntt(
&HostOrDeviceSlice::on_host(scalars[i * test_size..(i + 1) * test_size].to_vec()),
is_inverse,
&config,
&mut one_ntt_result,
)
.unwrap();
assert_eq!(
batch_ntt_result[i * test_size..(i + 1) * test_size],
*one_ntt_result.as_slice()
);
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.batch_size = batch_size as i32;
config.ntt_algorithm = alg;
let mut batch_ntt_result =
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
config.batch_size = 1;
let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
for i in 0..batch_size {
ntt(
&HostOrDeviceSlice::on_host(scalars[i * test_size..(i + 1) * test_size].to_vec()),
is_inverse,
&config,
&mut one_ntt_result,
)
.unwrap();
assert_eq!(
batch_ntt_result[i * test_size..(i + 1) * test_size],
*one_ntt_result.as_slice()
);
}
}
}
}
@@ -286,22 +306,25 @@ where
config
.ctx
.stream = &stream;
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
scalars_d
.copy_to_host_async(&mut intt_result_h, &stream)
.unwrap();
stream
.synchronize()
.unwrap();
assert_eq!(scalars_h, intt_result_h);
if coset_gen == F::one() {
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
ntt_out_d
.copy_to_host(&mut ntt_result_h)
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.ntt_algorithm = alg;
ntt(&scalars_d, NTTDir::kForward, &config, &mut ntt_out_d).unwrap();
ntt(&ntt_out_d, NTTDir::kInverse, &config, &mut scalars_d).unwrap();
let mut intt_result_h = vec![F::zero(); test_size * batch_size];
scalars_d
.copy_to_host_async(&mut intt_result_h, &stream)
.unwrap();
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
stream
.synchronize()
.unwrap();
assert_eq!(scalars_h, intt_result_h);
if coset_gen == F::one() {
let mut ntt_result_h = vec![F::zero(); test_size * batch_size];
ntt_out_d
.copy_to_host(&mut ntt_result_h)
.unwrap();
assert_eq!(sum_of_coeffs, ntt_result_h[0].to_ark());
}
}
}
}