Mirror of https://github.com/pseXperiments/icicle.git (synced 2026-01-09 15:37:58 -05:00)
merge NTT part
@@ -46,17 +46,10 @@ int main(int argc, char** argv)

  // init domain
  auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
<<<<<<< HEAD
  const bool is_radix2_alg = (argc > 1) ? atoi(argv[1]) : false;
  ntt_config.ntt_algorithm = is_radix2_alg ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix;

  const char* ntt_alg_str = is_radix2_alg ? "Radix-2" : "Mixed-Radix";
=======
  ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward
  ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false;

  const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix";
>>>>>>> main
  std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: ";

  CHK_IF_RETURN(cudaEventCreate(&start));
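Both sides of this conflict expose the same knob to the caller: force the radix-2 kernels or let the library use the mixed-radix path. A minimal sketch of how the example's command-line flag could drive that choice (assuming the HEAD-side NttAlgorithm enum; on the main side the same flag would set is_force_radix2 instead):

// Hypothetical helper, not part of the example: build a config from a CLI flag.
ntt::NTTConfig<test_scalar> make_config(bool force_radix2)
{
  auto cfg = ntt::DefaultNTTConfig<test_scalar>();
  cfg.ntt_algorithm = force_radix2 ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix;
  return cfg;
}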
@@ -85,10 +78,7 @@ int main(int argc, char** argv)
  // (3) NTT for A,B from cpu to gpu
  ntt_config.are_inputs_on_device = false;
  ntt_config.are_outputs_on_device = true;
<<<<<<< HEAD
  ntt_config.ordering = ntt::Ordering::kNM;
=======
>>>>>>> main
  CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
  CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));

@@ -100,10 +90,7 @@ int main(int argc, char** argv)
  // (5) INTT (in place)
  ntt_config.are_inputs_on_device = true;
  ntt_config.are_outputs_on_device = true;
<<<<<<< HEAD
  ntt_config.ordering = ntt::Ordering::kMN;
=======
>>>>>>> main
  CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));

  CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream));

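For orientation, the two hunks above are steps of the usual NTT-based polynomial multiplication pipeline: forward-transform both inputs, multiply pointwise in the evaluation domain, then inverse-transform the product in place. A condensed sketch of that flow reusing the example's buffer names; the pointwise_mul kernel for step (4) is an assumption of this sketch, not the example's actual code:

// Hypothetical pointwise-multiplication kernel for step (4) of the sketch.
__global__ void pointwise_mul(const test_scalar* a, const test_scalar* b, test_scalar* c, int n)
{
  int tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid < n) c[tid] = a[tid] * b[tid];
}

// (3) forward NTT of A and B (host inputs, device outputs)
ntt_config.are_inputs_on_device = false;
ntt_config.are_outputs_on_device = true;
CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));
// (4) pointwise product in the evaluation domain
pointwise_mul<<<(NTT_SIZE + 255) / 256, 256, 0, ntt_config.ctx.stream>>>(GpuA, GpuB, MulGpu, NTT_SIZE);
// (5) inverse NTT of the product, in place on the device
ntt_config.are_inputs_on_device = true;
CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));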
@@ -6,11 +6,7 @@

namespace ntt {

<<<<<<< HEAD
  static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
=======
  static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
>>>>>>> main
  {
    uint32_t rev_num = 0, temp, dig_len;
    if (dit) {
@@ -33,7 +29,6 @@ namespace ntt {
    return rev_num;
  }

<<<<<<< HEAD
  static inline __device__ uint32_t bit_rev(uint32_t num, uint32_t log_size) { return __brev(num) >> (32 - log_size); }

  enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural };
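dig_rev is the mixed-radix generalization of bit_rev: instead of reversing individual bits of the index it reverses whole digits. A host-side reference of the idea, with digit widths chosen purely for illustration (the kernels derive theirs from the stage table):

// Host-side reference for the idea behind dig_rev. Digit widths here are an
// assumption for illustration (e.g. {6, 5, 4} for log_size = 15), not the exact
// decomposition the kernels use.
#include <cstdint>
#include <vector>

uint32_t digit_reverse(uint32_t num, const std::vector<uint32_t>& digit_widths)
{
  uint32_t rev = 0;
  for (uint32_t w : digit_widths) {              // peel digits from the LSB side...
    rev = (rev << w) | (num & ((1u << w) - 1));  // ...and append them from the MSB side
    num >>= w;
  }
  return rev;
}

// Plain bit reversal is the special case where every digit is 1 bit wide:
// digit_reverse(x, {1, 1, ..., 1}) behaves like bit_rev(x, log_size).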
@@ -64,19 +59,10 @@ namespace ntt {
    E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
  {
    // launch N threads (per batch element)
=======
  // Note: the following reorder kernels are fused with normalization for INTT
  template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
  static __global__ void
  reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
  {
    // launch N threads
>>>>>>> main
    // each thread starts from one index and calculates the corresponding group
    // if its index is the smallest number in the group -> do the memory transformation
    // else --> do nothing

<<<<<<< HEAD
    const uint32_t size = 1 << log_size;
    const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t idx = tid % size;
@@ -94,32 +80,12 @@ namespace ntt {
      group[i++] = next_element + size * batch_idx;
    }

=======
    const uint32_t idx = blockDim.x * blockIdx.x + threadIdx.x;
    uint32_t next_element = idx;
    uint32_t group[MAX_GROUP_SIZE];
    group[0] = idx;

    uint32_t i = 1;
    for (; i < MAX_GROUP_SIZE;) {
      next_element = dig_rev(next_element, log_size, dit);
      if (next_element < idx) return; // not handling this group
      if (next_element == idx) break; // calculated whole group
      group[i++] = next_element;
    }

    if (i == 1) { // single element in group --> nothing to do (except maybe normalize for INTT)
      if (is_normalize) { arr[idx] = arr[idx] * inverse_N; }
      return;
    }
>>>>>>> main
    --i;
    // reaching here means I am handling this group
    const E last_element_in_group = arr[group[i]];
    for (; i > 0; --i) {
      arr[group[i]] = is_normalize ? (arr[group[i - 1]] * inverse_N) : arr[group[i - 1]];
    }
<<<<<<< HEAD
    arr[group[0]] = is_normalize ? (last_element_in_group * inverse_N) : last_element_in_group;
  }

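The comments above describe a cycle-leader permutation: every thread follows its index through the permutation, and only the smallest index of each cycle rotates that cycle's elements. A host-side reference of the same scheme, with perm standing in for dig_rev(., log_size, dit) and no batching:

// Host-side reference for the in-place reorder. perm must be a permutation of [0, n).
#include <cstdint>
#include <vector>

template <typename E, typename Perm>
void reorder_inplace_reference(std::vector<E>& arr, Perm perm)
{
  const uint32_t n = static_cast<uint32_t>(arr.size());
  for (uint32_t idx = 0; idx < n; ++idx) {
    std::vector<uint32_t> group{idx};              // walk the cycle starting at idx
    uint32_t next = perm(idx);
    bool leader = true;
    while (next != idx) {
      if (next < idx) { leader = false; break; }   // a smaller index owns this cycle
      group.push_back(next);
      next = perm(next);
    }
    if (!leader || group.size() == 1) continue;    // not our cycle, or a fixed point
    // rotate the cycle by one position, exactly like the kernel's copy loop
    E last = arr[group.back()];
    for (size_t i = group.size() - 1; i > 0; --i) arr[group[i]] = arr[group[i - 1]];
    arr[group.front()] = last;
  }
}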
@@ -131,23 +97,10 @@ namespace ntt {
    uint32_t rd = tid;
    uint32_t wr =
      ((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type);
=======
    arr[idx] = is_normalize ? (last_element_in_group * inverse_N) : last_element_in_group;
  }

  template <typename E, typename S>
  __launch_bounds__(64) __global__
  void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
  {
    uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
    uint32_t rd = tid;
    uint32_t wr = dig_rev(tid, log_size, dit);
>>>>>>> main
    arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
  }

  template <typename E, typename S>
<<<<<<< HEAD
  static __global__ void batch_elementwise_mul_with_reorder(
    E* in_vec,
    int n_elements,
@@ -168,8 +121,6 @@ namespace ntt {
  }

  template <typename E, typename S>
=======
>>>>>>> main
  __launch_bounds__(64) __global__ void ntt64(
    E* in,
    E* out,
@@ -178,10 +129,7 @@ namespace ntt {
    S* basic_twiddles,
    uint32_t log_size,
    uint32_t tw_log_size,
<<<<<<< HEAD
    uint32_t nof_ntt_blocks,
=======
>>>>>>> main
    uint32_t data_stride,
    uint32_t log_data_stride,
    uint32_t twiddle_stride,
@@ -200,11 +148,8 @@ namespace ntt {
    s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
    s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);

<<<<<<< HEAD
    if (s_meta.ntt_block_id >= nof_ntt_blocks) return;

=======
>>>>>>> main
    engine.loadBasicTwiddles(basic_twiddles, inv);
    engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
    if (twiddle_stride && dit) {
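The two index expressions above pack eight 64-point sub-NTTs into one 64-thread block; strided stages interleave a sub-NTT's eight threads across the block, while non-strided stages keep them consecutive. A small host program that prints the mapping for block 0 makes the two layouts visible (illustration only):

#include <cstdint>
#include <cstdio>

int main()
{
  for (int strided = 0; strided <= 1; ++strided) {
    std::printf("strided=%d\n", strided);
    for (uint32_t t = 0; t < 64; ++t) {  // threadIdx.x within a 64-thread block
      uint32_t block_id = (0u << 3) + (strided ? (t & 0x7) : (t >> 3)); // which 64-point sub-NTT
      uint32_t inp_id = strided ? (t >> 3) : (t & 0x7);                 // lane inside that sub-NTT
      std::printf("t=%2u -> ntt_block_id=%u ntt_inp_id=%u\n", t, block_id, inp_id);
    }
  }
  return 0;
}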
@@ -242,10 +187,7 @@ namespace ntt {
    S* basic_twiddles,
    uint32_t log_size,
    uint32_t tw_log_size,
<<<<<<< HEAD
    uint32_t nof_ntt_blocks,
=======
>>>>>>> main
    uint32_t data_stride,
    uint32_t log_data_stride,
    uint32_t twiddle_stride,
@@ -265,11 +207,8 @@ namespace ntt {
    s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
    s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);

<<<<<<< HEAD
    if (s_meta.ntt_block_id >= nof_ntt_blocks) return;

=======
>>>>>>> main
    engine.loadBasicTwiddles(basic_twiddles, inv);
    engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
    engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
@@ -295,12 +234,8 @@ namespace ntt {
    S* internal_twiddles,
    S* basic_twiddles,
    uint32_t log_size,
<<<<<<< HEAD
    uint32_t tw_log_size,
    uint32_t nof_ntt_blocks,
=======
    int32_t tw_log_size,
>>>>>>> main
    uint32_t data_stride,
    uint32_t log_data_stride,
    uint32_t twiddle_stride,
@@ -320,11 +255,8 @@ namespace ntt {
    s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
    s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);

<<<<<<< HEAD
    if (s_meta.ntt_block_id >= nof_ntt_blocks) return;

=======
>>>>>>> main
    engine.loadBasicTwiddles(basic_twiddles, inv);
    engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
    if (twiddle_stride) {
@@ -351,10 +283,7 @@ namespace ntt {
    S* basic_twiddles,
    uint32_t log_size,
    uint32_t tw_log_size,
<<<<<<< HEAD
    uint32_t nof_ntt_blocks,
=======
>>>>>>> main
    uint32_t data_stride,
    uint32_t log_data_stride,
    uint32_t twiddle_stride,
@@ -374,11 +303,8 @@ namespace ntt {
    s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
    s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);

<<<<<<< HEAD
    if (s_meta.ntt_block_id >= nof_ntt_blocks) return;

=======
>>>>>>> main
    engine.loadBasicTwiddles(basic_twiddles, inv);
    engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
    engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
@@ -405,10 +331,7 @@ namespace ntt {
    S* basic_twiddles,
    uint32_t log_size,
    uint32_t tw_log_size,
<<<<<<< HEAD
    uint32_t nof_ntt_blocks,
=======
>>>>>>> main
    uint32_t data_stride,
    uint32_t log_data_stride,
    uint32_t twiddle_stride,
@@ -428,11 +351,8 @@ namespace ntt {
    s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
    s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);

<<<<<<< HEAD
    if (s_meta.ntt_block_id >= nof_ntt_blocks) return;

=======
>>>>>>> main
    engine.loadBasicTwiddles(basic_twiddles, inv);
    engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
    if (twiddle_stride) {
@@ -590,10 +510,7 @@ namespace ntt {
    S* basic_twiddles,
    uint32_t log_size,
    uint32_t tw_log_size,
<<<<<<< HEAD
    uint32_t batch_size,
=======
>>>>>>> main
    bool inv,
    bool normalize,
    bool dit,
@@ -606,7 +523,6 @@ namespace ntt {
    }

    if (log_size == 4) {
<<<<<<< HEAD
      const int NOF_THREADS = min(64, 2 * batch_size);
      const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;

@@ -620,23 +536,10 @@ namespace ntt {
          false, 0, inv, dit);
      }
      if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
=======
      if (dit) {
        ntt16dit<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
          in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
          dit);
      } else { // dif
        ntt16<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
          in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
          dit);
      }
      if (normalize) normalize_kernel<<<1, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
>>>>>>> main
      return CHK_LAST();
    }

    if (log_size == 5) {
<<<<<<< HEAD
      const int NOF_THREADS = min(64, 4 * batch_size);
      const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
      if (dit) {
@@ -649,40 +552,20 @@ namespace ntt {
          false, 0, inv, dit);
      }
      if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
=======
      if (dit) {
        ntt32dit<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
          in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
          dit);
      } else { // dif
        ntt32<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
          in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
          dit);
      }
      if (normalize) normalize_kernel<<<1, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
>>>>>>> main
      return CHK_LAST();
    }

    if (log_size == 6) {
<<<<<<< HEAD
      const int NOF_THREADS = min(64, 8 * batch_size);
      const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
      ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
        in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
        false, 0, inv, dit);
      if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
=======
      ntt64<<<1, 8, 8 * 64 * sizeof(E), cuda_stream>>>(
        in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
        dit);
      if (normalize) normalize_kernel<<<1, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
>>>>>>> main
      return CHK_LAST();
    }

    if (log_size == 8) {
<<<<<<< HEAD
      const int NOF_THREADS = 64;
      const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
      if (dit) {
@@ -701,33 +584,11 @@ namespace ntt {
          (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit);
      }
      if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
=======
      if (dit) {
        ntt16dit<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
          in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
          dit);
        ntt16dit<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
          out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1,
          inv,
          dit); // we need threads 32+ although 16-31 are idle
      } else { // dif
        ntt16<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
          in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1, inv,
          dit); // we need threads 32+ although 16-31 are idle
        ntt16<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
          out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
          dit);
      }
      if (normalize) normalize_kernel<<<1, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
>>>>>>> main
      return CHK_LAST();
    }

    // general case:
<<<<<<< HEAD
    uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
=======
>>>>>>> main
    if (dit) {
      for (int i = 0; i < 5; i++) {
        uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
@@ -735,7 +596,6 @@ namespace ntt {
        for (int j = 0; j < i; j++)
          stride_log += STAGE_SIZES_HOST[log_size][j];
        if (stage_size == 6)
<<<<<<< HEAD
          ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
@@ -747,19 +607,6 @@ namespace ntt {
          ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
=======
          ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
        else if (stage_size == 5)
          ntt32dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
        else if (stage_size == 4)
          ntt16dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
>>>>>>> main
      }
    } else { // dif
      bool first_run = false, prev_stage = false;
@@ -770,7 +617,6 @@ namespace ntt {
          stride_log += STAGE_SIZES_HOST[log_size][j];
        first_run = stage_size && !prev_stage;
        if (stage_size == 6)
<<<<<<< HEAD
          ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
@@ -787,23 +633,6 @@ namespace ntt {
      }
      if (normalize)
        normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
=======
          ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
        else if (stage_size == 5)
          ntt32<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
        else if (stage_size == 4)
          ntt16<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
            1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
        prev_stage = stage_size;
      }
    }
    if (normalize) normalize_kernel<<<1 << (log_size - 8), 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
>>>>>>> main

    return CHK_LAST();
  }
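Past the special cases, the general path decomposes log_size into up to five stages of 4, 5 or 6 bits (radix 16/32/64), and each stage's data/twiddle stride is two to the sum of the preceding stage sizes. A host-side sketch with one assumed decomposition for 2^19; the real table is STAGE_SIZES_HOST in the source:

#include <cstdint>
#include <cstdio>

int main()
{
  const uint32_t log_size = 19;               // example: 2^19 elements
  const uint32_t stages[5] = {6, 5, 4, 4, 0}; // one assumed split into 4/5/6-bit digits (sums to 19)
  uint32_t stride_log = 0;
  for (int i = 0; i < 5 && stages[i]; ++i) {
    std::printf("stage %d: radix 2^%u, data_stride = 1 << %u\n", i, stages[i], stride_log);
    stride_log += stages[i];                  // the next stage strides over everything handled so far
  }
  return 0;
}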
@@ -817,21 +646,15 @@ namespace ntt {
    S* basic_twiddles,
    int ntt_size,
    int max_logn,
<<<<<<< HEAD
    int batch_size,
    bool is_inverse,
    Ordering ordering,
    S* arbitrary_coset,
    int coset_gen_index,
=======
    bool is_inverse,
    Ordering ordering,
>>>>>>> main
    cudaStream_t cuda_stream)
  {
    CHK_INIT_IF_RETURN();

<<<<<<< HEAD
    const int logn = int(log2(ntt_size));
    const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
    const int NOF_THREADS = min(64, (1 << logn) * batch_size);
@@ -890,38 +713,10 @@ namespace ntt {
    }
    is_normalize = false;
    d_input = d_output;
=======
    // TODO: can we support all orderings? Note that reversal is generally digit reverse (generalization of bit reverse)
    if (ordering != Ordering::kNN) {
      throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only");
    }

    const int logn = int(log2(ntt_size));

    const int NOF_BLOCKS = (1 << (max(logn, 6) - 6));
    const int NOF_THREADS = min(64, 1 << logn);

    const bool reverse_input = ordering == Ordering::kNN;
    const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
    bool is_normalize = is_inverse;

    if (reverse_input) {
      // Note: fusing reorder with normalize for INTT
      const bool is_reverse_in_place = (d_input == d_output);
      if (is_reverse_in_place) {
        reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
          d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
      } else {
        reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
          d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
      }
      is_normalize = false;
>>>>>>> main
    }

    // inplace ntt
    CHK_IF_RETURN(large_ntt(
<<<<<<< HEAD
      d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
      (is_normalize && reverse_output == eRevType::None), dit, cuda_stream));

@@ -935,10 +730,6 @@ namespace ntt {
      d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
      arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
    }
=======
      d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, is_inverse,
      is_normalize, is_dit, cuda_stream));
>>>>>>> main

    return CHK_LAST();
  }
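The reorder kernels above are launched with is_normalize set for an inverse transform and the flag is cleared afterwards; the fusion is valid because scaling by N^-1 commutes with a permutation. A tiny host-side check of that fact, with doubles standing in for field elements:

#include <cassert>
#include <vector>

int main()
{
  const int n = 8;
  const double inv_n = 1.0 / n;
  std::vector<double> a = {1, 2, 3, 4, 5, 6, 7, 8};
  auto perm = [n](int i) { return (i * 3) % n; };  // any permutation works; this one is arbitrary
  std::vector<double> scaled_then_moved(n), moved_then_scaled(n);
  for (int i = 0; i < n; ++i) {
    scaled_then_moved[perm(i)] = a[i] * inv_n;     // normalize during the reorder (fused)
    moved_then_scaled[perm(i)] = a[i];             // reorder first...
  }
  for (int i = 0; i < n; ++i) moved_then_scaled[i] *= inv_n; // ...then a separate normalize pass
  for (int i = 0; i < n; ++i) assert(scaled_then_moved[i] == moved_then_scaled[i]);
  return 0;
}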
@@ -960,16 +751,11 @@ namespace ntt {
    curve_config::scalar_t* basic_twiddles,
    int ntt_size,
    int max_logn,
<<<<<<< HEAD
    int batch_size,
    bool is_inverse,
    Ordering ordering,
    curve_config::scalar_t* arbitrary_coset,
    int coset_gen_index,
=======
    bool is_inverse,
    Ordering ordering,
>>>>>>> main
    cudaStream_t cuda_stream);

} // namespace ntt

@@ -8,11 +8,8 @@
#include "utils/utils_kernels.cuh"
#include "utils/utils.h"
#include "appUtils/ntt/ntt_impl.cuh"
<<<<<<< HEAD

#include <mutex>
=======
>>>>>>> main

namespace ntt {

@@ -366,7 +363,6 @@ namespace ntt {
  template <typename S>
  class Domain
  {
<<<<<<< HEAD
    // Mutex for protecting access to the domain/device container array
    static inline std::mutex device_domain_mutex;
    // The domain-per-device container - assumption is InitDomain is called once per device per program.
@@ -378,37 +374,21 @@ namespace ntt {

    S* internal_twiddles = nullptr; // required by mixed-radix NTT
    S* basic_twiddles = nullptr;    // required by mixed-radix NTT
=======
    static inline int max_size = 0;
    static inline int max_log_size = 0;
    static inline S* twiddles = nullptr;
    static inline std::unordered_map<S, int> coset_index = {};

    static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
    static inline S* basic_twiddles = nullptr;    // required by mixed-radix NTT
>>>>>>> main

  public:
    template <typename U>
    friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);

<<<<<<< HEAD
    cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
=======
    static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
>>>>>>> main

    template <typename U, typename E>
    friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
  };

  template <typename S>
<<<<<<< HEAD
  static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};

  template <typename S>
=======
>>>>>>> main
  cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
  {
    CHK_INIT_IF_RETURN();
@@ -416,7 +396,6 @@ namespace ntt {
    Domain<S>& domain = domains_for_devices<S>[ctx.device_id];

    // only generate twiddles if they haven't been generated yet
<<<<<<< HEAD
    // please note that this offers just basic thread-safety,
    // it's assumed a singleton (non-enforced) that is supposed
    // to be initialized once per device per program lifetime
@@ -426,18 +405,12 @@ namespace ntt {
      // double check locking
      if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain

=======
    // please note that this is not thread-safe at all,
    // but it's a singleton that is supposed to be initialized once per program lifetime
    if (!Domain<S>::twiddles) {
>>>>>>> main
      bool found_logn = false;
      S omega = primitive_root;
      unsigned omegas_count = S::get_omegas_count();
      for (int i = 0; i < omegas_count; i++) {
        omega = S::sqr(omega);
        if (!found_logn) {
<<<<<<< HEAD
          ++domain.max_log_size;
          found_logn = omega == S::one();
          if (found_logn) break;
@@ -457,33 +430,12 @@ namespace ntt {
      CHK_IF_RETURN(generate_external_twiddles_generic(
        primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
        ctx.stream));
=======
          ++Domain<S>::max_log_size;
          found_logn = omega == S::one();
          if (found_logn) break;
        }
      }
      Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
      if (omega != S::one()) {
        throw IcicleError(
          IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
      }

      // allocate and calculate twiddles on GPU
      // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
      // Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host
      CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
      CHK_IF_RETURN(generate_external_twiddles_generic(
        primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
        Domain<S>::max_log_size, ctx.stream));
>>>>>>> main
      CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));

      const bool is_map_only_powers_of_primitive_root = true;
      if (is_map_only_powers_of_primitive_root) {
        // populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
        // etc.)
<<<<<<< HEAD
        domain.coset_index[S::one()] = 0;
        for (int i = 0; i < domain.max_log_size; ++i) {
          const int index = (int)pow(2, i);
@@ -493,17 +445,6 @@ namespace ntt {
        // populate all values
        for (int i = 0; i < domain.max_size; ++i) {
          domain.coset_index[domain.twiddles[i]] = i;
=======
        Domain<S>::coset_index[S::one()] = 0;
        for (int i = 0; i < Domain<S>::max_log_size; ++i) {
          const int index = (int)pow(2, i);
          Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
        }
      } else {
        // populate all values
        for (int i = 0; i < Domain<S>::max_size; ++i) {
          Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
>>>>>>> main
        }
      }
    }
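A host-side sketch of the coset_index bookkeeping built above: the map lets the NTT entry point translate a coset generator back into an index into the twiddle table, and in the default mode only powers of the primitive root at power-of-two indices are stored, which keeps the map small. Plain integers stand in for field elements here:

#include <cstdint>
#include <unordered_map>
#include <vector>

int main()
{
  const int max_log_size = 10;
  const int max_size = 1 << max_log_size;
  std::vector<uint64_t> twiddles(max_size);                    // stand-in for the GPU twiddle table
  for (int i = 0; i < max_size; ++i) twiddles[i] = 1000 + i;   // fake but distinct values

  std::unordered_map<uint64_t, int> coset_index;
  coset_index[twiddles[0]] = 0;                                // the identity element, as with S::one()
  for (int i = 0; i < max_log_size; ++i) {
    const int index = 1 << i;                                  // same role as (int)pow(2, i) above
    coset_index[twiddles[index]] = index;
  }
  // coset_index now holds max_log_size + 1 entries instead of max_size
  return 0;
}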
@@ -529,7 +470,6 @@ namespace ntt {
    return CHK_LAST();
  }

<<<<<<< HEAD
  template <typename S>
  static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
  {
@@ -600,8 +540,6 @@ namespace ntt {
    return CHK_LAST();
  }

=======
>>>>>>> main
  template <typename S, typename E>
  cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
  {
@@ -643,7 +581,6 @@ namespace ntt {
    } else {
      CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
    }
<<<<<<< HEAD

    S* coset = nullptr;
    int coset_index = 0;
@@ -675,64 +612,10 @@ namespace ntt {
      batch_size, is_inverse, config.ordering, coset, coset_index, stream));
    }

=======

    S* coset = nullptr;
    int coset_index = 0;
    try {
      coset_index = Domain<S>::coset_index.at(config.coset_gen);
    } catch (...) {
      // if coset index is not found in the subgroup, compute coset powers on CPU and move them to device
      std::vector<S> h_coset;
      h_coset.push_back(S::one());
      S coset_gen = (dir == NTTDir::kInverse) ? S::inverse(config.coset_gen) : config.coset_gen;
      for (int i = 1; i < size; i++) {
        h_coset.push_back(h_coset.at(i - 1) * coset_gen);
      }
      CHK_IF_RETURN(cudaMallocAsync(&coset, size * sizeof(S), stream));
      CHK_IF_RETURN(cudaMemcpyAsync(coset, &h_coset.front(), size * sizeof(S), cudaMemcpyHostToDevice, stream));
      h_coset.clear();
    }

    const bool is_small_ntt = logn < 16; // cutoff point where mixed-radix is faster than radix-2
    const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
    const bool is_batch_ntt = batch_size > 1; // batch not supported by mixed-radix algorithm yet
    const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
    const bool is_radix2_algorithm = config.is_force_radix2 || is_batch_ntt || is_small_ntt || is_on_coset || !is_NN;

    if (is_radix2_algorithm) {
      bool ct_butterfly = true;
      bool reverse_input = false;
      switch (config.ordering) {
      case Ordering::kNN:
        reverse_input = true;
        break;
      case Ordering::kNR:
        ct_butterfly = false;
        break;
      case Ordering::kRR:
        reverse_input = true;
        ct_butterfly = false;
        break;
      }

      if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);

      CHK_IF_RETURN(ntt_inplace_batch_template(
        reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
        dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));

      if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
    } else { // mixed-radix algorithm
      CHK_IF_RETURN(ntt::mixed_radix_ntt(
        d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
        Domain<S>::max_log_size, dir == NTTDir::kInverse, config.ordering, stream));
    }

>>>>>>> main
    if (!are_outputs_on_device)
      CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));

    if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
    if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
    if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
    if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
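Restated as a small pure predicate (not a separate API), the main-side dispatch above falls back to radix-2 whenever any of its current limitations apply:

// Mirror of the heuristics in the hunk above: force flag, batches, small sizes,
// cosets and non-NN orderings all route to the radix-2 implementation.
static bool choose_radix2(bool force_radix2, int logn, int batch_size, bool has_coset, bool is_NN_ordering)
{
  const bool is_small_ntt = logn < 16;      // cutoff where mixed-radix starts to win, per the comment above
  const bool is_batch_ntt = batch_size > 1; // batches not yet supported by mixed-radix
  return force_radix2 || is_batch_ntt || is_small_ntt || has_coset || !is_NN_ordering;
}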
@@ -745,7 +628,6 @@ namespace ntt {
  {
    device_context::DeviceContext ctx = device_context::get_default_device_context();
    NTTConfig<S> config = {
<<<<<<< HEAD
      ctx,      // ctx
      S::one(), // coset_gen
      1,        // batch_size
@@ -754,16 +636,6 @@ namespace ntt {
      false,              // are_outputs_on_device
      false,              // is_async
      NttAlgorithm::Auto, // ntt_algorithm
=======
      ctx,           // ctx
      S::one(),      // coset_gen
      1,             // batch_size
      Ordering::kNN, // ordering
      false,         // are_inputs_on_device
      false,         // are_outputs_on_device
      false,         // is_async
      false,         // is_force_radix2
>>>>>>> main
    };
    return config;
  }

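A short usage sketch for the default config built above, written against the main-side field set (on HEAD the last field is ntt_algorithm rather than is_force_radix2); d_in and d_out are hypothetical device buffers:

auto cfg = ntt::DefaultNTTConfig<test_scalar>(); // NN ordering, batch 1, host buffers, blocking call
cfg.are_inputs_on_device = true;                 // override only what differs from the defaults
cfg.are_outputs_on_device = true;
cfg.is_force_radix2 = false;                     // let the heuristics pick mixed-radix when possible
// CHK_IF_RETURN(ntt::NTT(d_in, size, ntt::NTTDir::kForward, cfg, d_out));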
@@ -100,13 +100,8 @@ namespace ntt {
     * non-blocking and you'd need to synchronize it explicitly by running
     * `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT
     * function will block the current CPU thread. */
<<<<<<< HEAD
    NttAlgorithm ntt_algorithm; /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
                                     selects radix-2 or mixed-radix algorithm based on heuristics). */
=======
    bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects
                               radix-2 or mixed-radix algorithm based on heuristics). */
>>>>>>> main
  };

  /**

@@ -25,16 +25,11 @@ namespace ntt {
    S* basic_twiddles,
    int ntt_size,
    int max_logn,
<<<<<<< HEAD
    int batch_size,
    bool is_inverse,
    Ordering ordering,
    S* arbitrary_coset,
    int coset_gen_index,
=======
    bool is_inverse,
    Ordering ordering,
>>>>>>> main
    cudaStream_t cuda_stream);

} // namespace ntt

@@ -3,7 +3,6 @@

#include "primitives/field.cuh"
#include "primitives/projective.cuh"
#include "utils/cuda_utils.cuh"
#include <chrono>
#include <iostream>
#include <vector>
@@ -26,11 +25,7 @@ void random_samples(test_data* res, uint32_t count)
void incremental_values(test_scalar* res, uint32_t count)
{
  for (int i = 0; i < count; i++) {
<<<<<<< HEAD
    res[i] = i ? res[i - 1] + test_scalar::one() : test_scalar::zero();
=======
    res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
>>>>>>> main
  }
}

@@ -39,7 +34,6 @@ int main(int argc, char** argv)
  cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
  float icicle_time, new_time;

<<<<<<< HEAD
  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming second input is the log-size
  int NTT_SIZE = 1 << NTT_LOG_SIZE;
  bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
@@ -61,32 +55,13 @@ int main(int argc, char** argv)
    BATCH_SIZE, COSET_IDX, ordering_str);

  CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
=======
  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // assuming second input is the log-size
  int NTT_SIZE = 1 << NTT_LOG_SIZE;
  bool INPLACE = (argc > 2) ? atoi(argv[2]) : true;
  int INV = (argc > 3) ? atoi(argv[3]) : true;

  const ntt::Ordering ordering = ntt::Ordering::kNN;
  const char* ordering_str = ordering == ntt::Ordering::kNN   ? "NN"
                             : ordering == ntt::Ordering::kNR ? "NR"
                             : ordering == ntt::Ordering::kRN ? "RN"
                                                              : "RR";

  printf("running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV);

  cudaFree(nullptr); // init GPU context (warmup)
>>>>>>> main

  // init domain
  auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
  ntt_config.ordering = ordering;
  ntt_config.are_inputs_on_device = true;
  ntt_config.are_outputs_on_device = true;
<<<<<<< HEAD
  ntt_config.batch_size = BATCH_SIZE;
=======
>>>>>>> main

  CHK_IF_RETURN(cudaEventCreate(&icicle_start));
  CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
@@ -101,7 +76,6 @@ int main(int argc, char** argv)
  std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;

  // cpu allocation
<<<<<<< HEAD
  auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE * BATCH_SIZE);
  auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE * BATCH_SIZE);
  auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE * BATCH_SIZE);
@@ -136,36 +110,6 @@ int main(int argc, char** argv)
      CHK_IF_RETURN(ntt::NTT(
        INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
        GpuOutputNew));
=======
  auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE);
  auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE);
  auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE);

  // gpu allocation
  test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
  CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE));
  CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE));
  CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE));

  // init inputs
  incremental_values(CpuScalars.get(), NTT_SIZE);
  CHK_IF_RETURN(cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE, cudaMemcpyHostToDevice));

  // inplace
  if (INPLACE) {
    CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
  }

  // run ntt
  auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
    // NEW
    CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
    ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt)
    for (size_t i = 0; i < iterations; i++) {
      ntt::NTT(
        INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
        GpuOutputNew);
>>>>>>> main
    }
    CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
    CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
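The benchmark lambda relies on the standard CUDA event-timing pattern; stripped of the example's specifics it looks like the following sketch (stream is a stand-in for ntt_config.ctx.stream):

cudaEvent_t start, stop;
CHK_IF_RETURN(cudaEventCreate(&start));
CHK_IF_RETURN(cudaEventCreate(&stop));
CHK_IF_RETURN(cudaEventRecord(start, stream)); // enqueue start marker on the same stream as the work
// ... launch the kernels being measured on `stream` ...
CHK_IF_RETURN(cudaEventRecord(stop, stream));
CHK_IF_RETURN(cudaEventSynchronize(stop));     // wait until the stop marker has executed
float elapsed_ms = 0.f;
CHK_IF_RETURN(cudaEventElapsedTime(&elapsed_ms, start, stop));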
@@ -174,16 +118,10 @@ int main(int argc, char** argv)

    // OLD
    CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
<<<<<<< HEAD
    ntt_config.ntt_algorithm = ntt::NttAlgorithm::Radix2;
    for (size_t i = 0; i < iterations; i++) {
      CHK_IF_RETURN(
        ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld));
=======
    ntt_config.is_force_radix2 = true;
    for (size_t i = 0; i < iterations; i++) {
      ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld);
>>>>>>> main
    }
    CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
    CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
@@ -201,17 +139,12 @@ int main(int argc, char** argv)
  CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
  int count = INPLACE ? 1 : 10;
  if (INPLACE) {
<<<<<<< HEAD
    CHK_IF_RETURN(
      cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
=======
    CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
>>>>>>> main
  }
  CHK_IF_RETURN(benchmark(true /*=print*/, count));

  // verify
<<<<<<< HEAD
  CHK_IF_RETURN(
    cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
  CHK_IF_RETURN(
@@ -219,13 +152,6 @@ int main(int argc, char** argv)

  bool success = true;
  for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) {
=======
  CHK_IF_RETURN(cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
  CHK_IF_RETURN(cudaMemcpy(CpuOutputOld.get(), GpuOutputOld, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));

  bool success = true;
  for (int i = 0; i < NTT_SIZE; i++) {
>>>>>>> main
    if (CpuOutputNew[i] != CpuOutputOld[i]) {
      success = false;
      // std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;