merge NTT part

Yuval Shekel
2024-02-15 10:20:37 +02:00
parent fd08925ed4
commit ba6c3ae59c
6 changed files with 1 addition and 440 deletions

View File

@@ -46,17 +46,10 @@ int main(int argc, char** argv)
// init domain
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
<<<<<<< HEAD
const bool is_radix2_alg = (argc > 1) ? atoi(argv[1]) : false;
ntt_config.ntt_algorithm = is_radix2_alg ? ntt::NttAlgorithm::Radix2 : ntt::NttAlgorithm::MixedRadix;
const char* ntt_alg_str = is_radix2_alg ? "Radix-2" : "Mixed-Radix";
=======
ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward
ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false;
const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix";
>>>>>>> main
std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: ";
CHK_IF_RETURN(cudaEventCreate(&start));
@@ -85,10 +78,7 @@ int main(int argc, char** argv)
// (3) NTT for A,B from cpu to gpu
ntt_config.are_inputs_on_device = false;
ntt_config.are_outputs_on_device = true;
<<<<<<< HEAD
ntt_config.ordering = ntt::Ordering::kNM;
=======
>>>>>>> main
CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));
@@ -100,10 +90,7 @@ int main(int argc, char** argv)
// (5) INTT (in place)
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
<<<<<<< HEAD
ntt_config.ordering = ntt::Ordering::kMN;
=======
>>>>>>> main
CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));
CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream));
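
Editor's note: after this merge the sample keeps the HEAD-side flow: forward NTTs in kNM ordering, a pointwise product on the GPU, then an in-place inverse NTT in kMN ordering. Below is a minimal sketch of that call sequence, reusing only the identifiers visible in the hunks above; mul_kernel, blocks and threads are placeholders for the pointwise-multiply step this diff does not show, and the buffers (CpuA, GpuA, MulGpu, ...) are assumed to be allocated as in the sample.

    // Sketch only: assumes icicle's ntt API as it appears in this diff.
    auto cfg = ntt::DefaultNTTConfig<test_scalar>();
    cfg.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;   // or Radix2 / Auto

    // forward NTT of both inputs, host -> device, NM ordering
    cfg.are_inputs_on_device = false;
    cfg.are_outputs_on_device = true;
    cfg.ordering = ntt::Ordering::kNM;
    CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, cfg, GpuA));
    CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, cfg, GpuB));

    // pointwise product in the evaluation domain (kernel not shown in this diff)
    mul_kernel<<<blocks, threads, 0, cfg.ctx.stream>>>(GpuA, GpuB, NTT_SIZE, MulGpu);

    // inverse NTT back to coefficients, in place, MN ordering
    cfg.are_inputs_on_device = true;
    cfg.ordering = ntt::Ordering::kMN;
    CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, cfg, MulGpu));
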

View File

@@ -6,11 +6,7 @@
namespace ntt {
<<<<<<< HEAD
static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
=======
static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
>>>>>>> main
{
uint32_t rev_num = 0, temp, dig_len;
if (dit) {
@@ -33,7 +29,6 @@ namespace ntt {
return rev_num;
}
<<<<<<< HEAD
static inline __device__ uint32_t bit_rev(uint32_t num, uint32_t log_size) { return __brev(num) >> (32 - log_size); }
enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural };
@@ -64,19 +59,10 @@ namespace ntt {
E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
{
// launch N threads (per batch element)
=======
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void
reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
{
// launch N threads
>>>>>>> main
// each thread starts from one index and calculates the corresponding group
// if its index is the smallest number in the group -> do the memory transformation
// else --> do nothing
<<<<<<< HEAD
const uint32_t size = 1 << log_size;
const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t idx = tid % size;
@@ -94,32 +80,12 @@ namespace ntt {
group[i++] = next_element + size * batch_idx;
}
=======
const uint32_t idx = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t next_element = idx;
uint32_t group[MAX_GROUP_SIZE];
group[0] = idx;
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = dig_rev(next_element, log_size, dit);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element;
}
if (i == 1) { // single element in group --> nothing to do (except maybe normalize for INTT)
if (is_normalize) { arr[idx] = arr[idx] * inverse_N; }
return;
}
>>>>>>> main
--i;
// reaching here means I am handling this group
const E last_element_in_group = arr[group[i]];
for (; i > 0; --i) {
arr[group[i]] = is_normalize ? (arr[group[i - 1]] * inverse_N) : arr[group[i - 1]];
}
<<<<<<< HEAD
arr[group[0]] = is_normalize ? (last_element_in_group * inverse_N) : last_element_in_group;
}
@@ -131,23 +97,10 @@ namespace ntt {
uint32_t rd = tid;
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type);
=======
arr[idx] = is_normalize ? (last_element_in_group * inverse_N) : last_element_in_group;
}
template <typename E, typename S>
__launch_bounds__(64) __global__
void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t rd = tid;
uint32_t wr = dig_rev(tid, log_size, dit);
>>>>>>> main
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
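
Editor's note: the in-place variant above relies on a cycle-following trick: conceptually one thread per index, each thread walks the digit-reverse permutation starting from its own index, and only the thread whose index is the smallest in its cycle performs the rotation (the walk is bounded by MAX_GROUP_SIZE, and normalization is fused in for INTT). A host-side sketch of the same idea, using plain bit reversal as the permutation so it stays self-contained (the real kernel uses dig_rev, whose digit widths come from the stage decomposition):

    #include <cstdint>
    #include <vector>

    // Host equivalent of the device bit_rev above: reverse the low log_size bits.
    static uint32_t host_bit_rev(uint32_t num, uint32_t log_size)
    {
      uint32_t rev = 0;
      for (uint32_t b = 0; b < 32; ++b) rev |= ((num >> b) & 1u) << (31 - b);
      return rev >> (32 - log_size);
    }

    // Reorder arr in place so the element at x ends up at position perm(x),
    // mirroring reorder_digits_inplace_kernel's ownership rule.
    static void reorder_inplace_host(std::vector<uint32_t>& arr, uint32_t log_size)
    {
      const uint32_t size = 1u << log_size;
      for (uint32_t idx = 0; idx < size; ++idx) { // one "thread" per index
        std::vector<uint32_t> group{idx};
        uint32_t next = idx;
        bool owner = true;
        while (true) {
          next = host_bit_rev(next, log_size);
          if (next < idx) { owner = false; break; } // a smaller index owns this cycle
          if (next == idx) break;                   // cycle closed
          group.push_back(next);
        }
        if (!owner || group.size() == 1) continue;  // nothing to move
        // shift values forward along the cycle, exactly as the kernel does
        const uint32_t last = arr[group.back()];
        for (size_t i = group.size() - 1; i > 0; --i) arr[group[i]] = arr[group[i - 1]];
        arr[group.front()] = last;
      }
    }
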
template <typename E, typename S>
<<<<<<< HEAD
static __global__ void batch_elementwise_mul_with_reorder(
E* in_vec,
int n_elements,
@@ -168,8 +121,6 @@ namespace ntt {
}
template <typename E, typename S>
=======
>>>>>>> main
__launch_bounds__(64) __global__ void ntt64(
E* in,
E* out,
@@ -178,10 +129,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
<<<<<<< HEAD
uint32_t nof_ntt_blocks,
=======
>>>>>>> main
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
@@ -200,11 +148,8 @@ namespace ntt {
s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);
<<<<<<< HEAD
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
=======
>>>>>>> main
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride && dit) {
@@ -242,10 +187,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
<<<<<<< HEAD
uint32_t nof_ntt_blocks,
=======
>>>>>>> main
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
@@ -265,11 +207,8 @@ namespace ntt {
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
<<<<<<< HEAD
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
=======
>>>>>>> main
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
@@ -295,12 +234,8 @@ namespace ntt {
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
<<<<<<< HEAD
uint32_t tw_log_size,
uint32_t nof_ntt_blocks,
=======
int32_t tw_log_size,
>>>>>>> main
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
@@ -320,11 +255,8 @@ namespace ntt {
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
<<<<<<< HEAD
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
=======
>>>>>>> main
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride) {
@@ -351,10 +283,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
<<<<<<< HEAD
uint32_t nof_ntt_blocks,
=======
>>>>>>> main
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
@@ -374,11 +303,8 @@ namespace ntt {
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
<<<<<<< HEAD
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
=======
>>>>>>> main
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
@@ -405,10 +331,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
<<<<<<< HEAD
uint32_t nof_ntt_blocks,
=======
>>>>>>> main
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
@@ -428,11 +351,8 @@ namespace ntt {
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
<<<<<<< HEAD
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
=======
>>>>>>> main
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride) {
@@ -590,10 +510,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
<<<<<<< HEAD
uint32_t batch_size,
=======
>>>>>>> main
bool inv,
bool normalize,
bool dit,
@@ -606,7 +523,6 @@ namespace ntt {
}
if (log_size == 4) {
<<<<<<< HEAD
const int NOF_THREADS = min(64, 2 * batch_size);
const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
@@ -620,23 +536,10 @@ namespace ntt {
false, 0, inv, dit);
}
if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
=======
if (dit) {
ntt16dit<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
} else { // dif
ntt16<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
}
if (normalize) normalize_kernel<<<1, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
>>>>>>> main
return CHK_LAST();
}
if (log_size == 5) {
<<<<<<< HEAD
const int NOF_THREADS = min(64, 4 * batch_size);
const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
if (dit) {
@@ -649,40 +552,20 @@ namespace ntt {
false, 0, inv, dit);
}
if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
=======
if (dit) {
ntt32dit<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
} else { // dif
ntt32<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
}
if (normalize) normalize_kernel<<<1, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
>>>>>>> main
return CHK_LAST();
}
if (log_size == 6) {
<<<<<<< HEAD
const int NOF_THREADS = min(64, 8 * batch_size);
const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit);
if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
=======
ntt64<<<1, 8, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
if (normalize) normalize_kernel<<<1, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
>>>>>>> main
return CHK_LAST();
}
if (log_size == 8) {
<<<<<<< HEAD
const int NOF_THREADS = 64;
const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
if (dit) {
@@ -701,33 +584,11 @@ namespace ntt {
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit);
}
if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
=======
if (dit) {
ntt16dit<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
ntt16dit<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1,
inv,
dit); // we need threads 32+ although 16-31 are idle
} else { // dif
ntt16<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1, inv,
dit); // we need threads 32+ although 16-31 are idle
ntt16<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
}
if (normalize) normalize_kernel<<<1, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
>>>>>>> main
return CHK_LAST();
}
// general case:
<<<<<<< HEAD
uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
=======
>>>>>>> main
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
@@ -735,7 +596,6 @@ namespace ntt {
for (int j = 0; j < i; j++)
stride_log += STAGE_SIZES_HOST[log_size][j];
if (stage_size == 6)
<<<<<<< HEAD
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
@@ -747,19 +607,6 @@ namespace ntt {
ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
=======
ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 5)
ntt32dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 4)
ntt16dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
>>>>>>> main
}
} else { // dif
bool first_run = false, prev_stage = false;
@@ -770,7 +617,6 @@ namespace ntt {
stride_log += STAGE_SIZES_HOST[log_size][j];
first_run = stage_size && !prev_stage;
if (stage_size == 6)
<<<<<<< HEAD
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
@@ -787,23 +633,6 @@ namespace ntt {
}
if (normalize)
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
=======
ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 5)
ntt32<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 4)
ntt16<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
prev_stage = stage_size;
}
}
if (normalize) normalize_kernel<<<1 << (log_size - 8), 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
>>>>>>> main
return CHK_LAST();
}
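
Editor's note: the batched dispatch on the HEAD side sizes its grids from the number of per-transform threads: ntt16/ntt32/ntt64 use 2/4/8 threads per transform respectively (each thread handles 8 elements), capped at 64 threads per block. A small host sketch of that arithmetic, under the assumption that log_size is 4, 5 or 6 as in the branches above (log_size 8 and the general case use different formulas):

    #include <algorithm>

    struct LaunchDims { int blocks; int threads; };

    // Threads-per-transform for the small cases: 16 -> 2, 32 -> 4, 64 -> 8,
    // i.e. 1 << (log_size - 3), since every thread handles 8 elements.
    static LaunchDims small_ntt_dims(int log_size, int batch_size)
    {
      const int threads_per_ntt = 1 << (log_size - 3);  // assumes log_size in {4,5,6}
      const int total_threads = threads_per_ntt * batch_size;
      const int threads = std::min(64, total_threads);  // at most 64 threads per block
      const int blocks = (total_threads + threads - 1) / threads;
      return {blocks, threads};
    }

    // e.g. log_size = 6, batch_size = 16 -> 128 threads total
    //      -> 64 threads per block, 2 blocks, as in the ntt64 branch above.
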
@@ -817,21 +646,15 @@ namespace ntt {
S* basic_twiddles,
int ntt_size,
int max_logn,
<<<<<<< HEAD
int batch_size,
bool is_inverse,
Ordering ordering,
S* arbitrary_coset,
int coset_gen_index,
=======
bool is_inverse,
Ordering ordering,
>>>>>>> main
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();
<<<<<<< HEAD
const int logn = int(log2(ntt_size));
const int NOF_BLOCKS = ((1 << logn) * batch_size + 64 - 1) / 64;
const int NOF_THREADS = min(64, (1 << logn) * batch_size);
@@ -890,38 +713,10 @@ namespace ntt {
}
is_normalize = false;
d_input = d_output;
=======
// TODO: can we support all orderings? Note that reversal is generally digit reverse (generalization of bit reverse)
if (ordering != Ordering::kNN) {
throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only");
}
const int logn = int(log2(ntt_size));
const int NOF_BLOCKS = (1 << (max(logn, 6) - 6));
const int NOF_THREADS = min(64, 1 << logn);
const bool reverse_input = ordering == Ordering::kNN;
const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
bool is_normalize = is_inverse;
if (reverse_input) {
// Note: fusing reorder with normalize for INTT
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
>>>>>>> main
}
// inplace ntt
CHK_IF_RETURN(large_ntt(
<<<<<<< HEAD
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
(is_normalize && reverse_output == eRevType::None), dit, cuda_stream));
@@ -935,10 +730,6 @@ namespace ntt {
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
}
=======
d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, is_inverse,
is_normalize, is_dit, cuda_stream));
>>>>>>> main
return CHK_LAST();
}
@@ -960,16 +751,11 @@ namespace ntt {
curve_config::scalar_t* basic_twiddles,
int ntt_size,
int max_logn,
<<<<<<< HEAD
int batch_size,
bool is_inverse,
Ordering ordering,
curve_config::scalar_t* arbitrary_coset,
int coset_gen_index,
=======
bool is_inverse,
Ordering ordering,
>>>>>>> main
cudaStream_t cuda_stream);
} // namespace ntt

View File

@@ -8,11 +8,8 @@
#include "utils/utils_kernels.cuh"
#include "utils/utils.h"
#include "appUtils/ntt/ntt_impl.cuh"
<<<<<<< HEAD
#include <mutex>
=======
>>>>>>> main
namespace ntt {
@@ -366,7 +363,6 @@ namespace ntt {
template <typename S>
class Domain
{
<<<<<<< HEAD
// Mutex for protecting access to the domain/device container array
static inline std::mutex device_domain_mutex;
// The domain-per-device container - assumption is InitDomain is called once per device per program.
@@ -378,37 +374,21 @@ namespace ntt {
S* internal_twiddles = nullptr; // required by mixed-radix NTT
S* basic_twiddles = nullptr; // required by mixed-radix NTT
=======
static inline int max_size = 0;
static inline int max_log_size = 0;
static inline S* twiddles = nullptr;
static inline std::unordered_map<S, int> coset_index = {};
static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
static inline S* basic_twiddles = nullptr; // required by mixed-radix NTT
>>>>>>> main
public:
template <typename U>
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
<<<<<<< HEAD
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
=======
static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
>>>>>>> main
template <typename U, typename E>
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
};
template <typename S>
<<<<<<< HEAD
static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};
template <typename S>
=======
>>>>>>> main
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
{
CHK_INIT_IF_RETURN();
@@ -416,7 +396,6 @@ namespace ntt {
Domain<S>& domain = domains_for_devices<S>[ctx.device_id];
// only generate twiddles if they haven't been generated yet
<<<<<<< HEAD
// please note that this offers just basic thread-safety,
// it's assumed a singleton (non-enforced) that is supposed
// to be initialized once per device per program lifetime
@@ -426,18 +405,12 @@ namespace ntt {
// double check locking
if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain
=======
// please note that this is not thread-safe at all,
// but it's a singleton that is supposed to be initialized once per program lifetime
if (!Domain<S>::twiddles) {
>>>>>>> main
bool found_logn = false;
S omega = primitive_root;
unsigned omegas_count = S::get_omegas_count();
for (int i = 0; i < omegas_count; i++) {
omega = S::sqr(omega);
if (!found_logn) {
<<<<<<< HEAD
++domain.max_log_size;
found_logn = omega == S::one();
if (found_logn) break;
@@ -457,33 +430,12 @@ namespace ntt {
CHK_IF_RETURN(generate_external_twiddles_generic(
primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
ctx.stream));
=======
++Domain<S>::max_log_size;
found_logn = omega == S::one();
if (found_logn) break;
}
}
Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
if (omega != S::one()) {
throw IcicleError(
IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
}
// allocate and calculate twiddles on GPU
// Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
// Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host
CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
CHK_IF_RETURN(generate_external_twiddles_generic(
primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
Domain<S>::max_log_size, ctx.stream));
>>>>>>> main
CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
const bool is_map_only_powers_of_primitive_root = true;
if (is_map_only_powers_of_primitive_root) {
// populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
// etc.)
<<<<<<< HEAD
domain.coset_index[S::one()] = 0;
for (int i = 0; i < domain.max_log_size; ++i) {
const int index = (int)pow(2, i);
@@ -493,17 +445,6 @@ namespace ntt {
// populate all values
for (int i = 0; i < domain.max_size; ++i) {
domain.coset_index[domain.twiddles[i]] = i;
=======
Domain<S>::coset_index[S::one()] = 0;
for (int i = 0; i < Domain<S>::max_log_size; ++i) {
const int index = (int)pow(2, i);
Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
}
} else {
// populate all values
for (int i = 0; i < Domain<S>::max_size; ++i) {
Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
>>>>>>> main
}
}
}
@@ -529,7 +470,6 @@ namespace ntt {
return CHK_LAST();
}
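
Editor's note: the HEAD side guards the per-device domain with a mutex plus the double-checked pattern mentioned in the comments above: check domain.twiddles without the lock, take the lock, check again, and only then generate the twiddles. A generic, self-contained sketch of that pattern (the names are illustrative, not taken from this file):

    #include <mutex>

    // Mirrors the check -> lock -> re-check flow used for domain.twiddles above.
    // A strictly conforming C++ version would make the unlocked read atomic (or
    // use std::call_once); the repo comment itself only claims "basic thread-safety".
    struct DomainSketch {
      int* twiddles = nullptr;                    // stand-in for the twiddle table
      std::mutex device_domain_mutex;

      void init_once()
      {
        if (twiddles) return;                     // fast path, no lock
        std::lock_guard<std::mutex> lock(device_domain_mutex);
        if (twiddles) return;                     // double check under the lock
        twiddles = new int[1];                    // expensive one-time generation
      }
    };
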
<<<<<<< HEAD
template <typename S>
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
{
@@ -600,8 +540,6 @@ namespace ntt {
return CHK_LAST();
}
=======
>>>>>>> main
template <typename S, typename E>
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
{
@@ -643,7 +581,6 @@ namespace ntt {
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
}
<<<<<<< HEAD
S* coset = nullptr;
int coset_index = 0;
@@ -675,64 +612,10 @@ namespace ntt {
batch_size, is_inverse, config.ordering, coset, coset_index, stream));
}
=======
S* coset = nullptr;
int coset_index = 0;
try {
coset_index = Domain<S>::coset_index.at(config.coset_gen);
} catch (...) {
// if coset index is not found in the subgroup, compute coset powers on CPU and move them to device
std::vector<S> h_coset;
h_coset.push_back(S::one());
S coset_gen = (dir == NTTDir::kInverse) ? S::inverse(config.coset_gen) : config.coset_gen;
for (int i = 1; i < size; i++) {
h_coset.push_back(h_coset.at(i - 1) * coset_gen);
}
CHK_IF_RETURN(cudaMallocAsync(&coset, size * sizeof(S), stream));
CHK_IF_RETURN(cudaMemcpyAsync(coset, &h_coset.front(), size * sizeof(S), cudaMemcpyHostToDevice, stream));
h_coset.clear();
}
const bool is_small_ntt = logn < 16; // cutoff point where mixed-radix is faster than radix-2
const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
const bool is_batch_ntt = batch_size > 1; // batch not supported by mixed-radix algorithm yet
const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
const bool is_radix2_algorithm = config.is_force_radix2 || is_batch_ntt || is_small_ntt || is_on_coset || !is_NN;
if (is_radix2_algorithm) {
bool ct_butterfly = true;
bool reverse_input = false;
switch (config.ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
ct_butterfly = false;
break;
case Ordering::kRR:
reverse_input = true;
ct_butterfly = false;
break;
}
if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
} else { // mixed-radix algorithm
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
Domain<S>::max_log_size, dir == NTTDir::kInverse, config.ordering, stream));
}
>>>>>>> main
if (!are_outputs_on_device)
CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
@@ -745,7 +628,6 @@ namespace ntt {
{
device_context::DeviceContext ctx = device_context::get_default_device_context();
NTTConfig<S> config = {
<<<<<<< HEAD
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
@@ -754,16 +636,6 @@ namespace ntt {
false, // are_outputs_on_device
false, // is_async
NttAlgorithm::Auto, // ntt_algorithm
=======
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
false, // is_force_radix2
>>>>>>> main
};
return config;
}
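
Editor's note: the main-side heuristic removed above falls back to radix-2 whenever the transform is small (logn < 16), batched, on a coset, not NN-ordered, or explicitly forced; on the HEAD side the same decision moves into is_choose_radix2_algorithm and is driven by NttAlgorithm::Auto. A self-contained sketch of the removed main-side decision, with the inputs passed explicitly:

    // Mirrors the main-side selection logic deleted in this merge; HEAD replaces
    // it with is_choose_radix2_algorithm() keyed off NttAlgorithm::Auto.
    static bool choose_radix2(
      bool force_radix2, int logn, int batch_size, bool on_coset, bool is_NN_ordering)
    {
      const bool is_small_ntt = logn < 16;      // cutoff where mixed-radix wins
      const bool is_batch_ntt = batch_size > 1; // batch unsupported by mixed-radix on main
      return force_radix2 || is_batch_ntt || is_small_ntt || on_coset || !is_NN_ordering;
    }

    // e.g. choose_radix2(false, 19, 1, false, true) == false  -> mixed-radix path
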

View File

@@ -100,13 +100,8 @@ namespace ntt {
* non-blocking and you'd need to synchronize it explicitly by running
* `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT
* function will block the current CPU thread. */
<<<<<<< HEAD
NttAlgorithm ntt_algorithm; /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
selects radix-2 or mixed-radix algorithm based on heuristics). */
=======
bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects
radix-2 or mixed-radix algorithm based on heuristics). */
>>>>>>> main
};
/**
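
Editor's note: in short, the field documented here changes from a boolean to an enum across the merge. A rough migration sketch (cfg stands for an NTTConfig obtained from DefaultNTTConfig; the false -> Auto mapping is inferred from the "heuristics" wording in both doc strings):

    // main side (removed):               HEAD side (kept):
    // cfg.is_force_radix2 = true;   ->   cfg.ntt_algorithm = ntt::NttAlgorithm::Radix2;
    // cfg.is_force_radix2 = false;  ->   cfg.ntt_algorithm = ntt::NttAlgorithm::Auto;       // heuristics decide
    //                                    cfg.ntt_algorithm = ntt::NttAlgorithm::MixedRadix; // new explicit option
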

View File

@@ -25,16 +25,11 @@ namespace ntt {
S* basic_twiddles,
int ntt_size,
int max_logn,
<<<<<<< HEAD
int batch_size,
bool is_inverse,
Ordering ordering,
S* arbitrary_coset,
int coset_gen_index,
=======
bool is_inverse,
Ordering ordering,
>>>>>>> main
cudaStream_t cuda_stream);
} // namespace ntt

View File

@@ -3,7 +3,6 @@
#include "primitives/field.cuh"
#include "primitives/projective.cuh"
#include "utils/cuda_utils.cuh"
#include <chrono>
#include <iostream>
#include <vector>
@@ -26,11 +25,7 @@ void random_samples(test_data* res, uint32_t count)
void incremental_values(test_scalar* res, uint32_t count)
{
for (int i = 0; i < count; i++) {
<<<<<<< HEAD
res[i] = i ? res[i - 1] + test_scalar::one() : test_scalar::zero();
=======
res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
>>>>>>> main
}
}
@@ -39,7 +34,6 @@ int main(int argc, char** argv)
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
float icicle_time, new_time;
<<<<<<< HEAD
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming second input is the log-size
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
@@ -61,32 +55,13 @@ int main(int argc, char** argv)
BATCH_SIZE, COSET_IDX, ordering_str);
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
=======
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // assuming second input is the log-size
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : true;
int INV = (argc > 3) ? atoi(argv[3]) : true;
const ntt::Ordering ordering = ntt::Ordering::kNN;
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
: ordering == ntt::Ordering::kNR ? "NR"
: ordering == ntt::Ordering::kRN ? "RN"
: "RR";
printf("running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV);
cudaFree(nullptr); // init GPU context (warmup)
>>>>>>> main
// init domain
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
ntt_config.ordering = ordering;
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
<<<<<<< HEAD
ntt_config.batch_size = BATCH_SIZE;
=======
>>>>>>> main
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
@@ -101,7 +76,6 @@ int main(int argc, char** argv)
std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;
// cpu allocation
<<<<<<< HEAD
auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE * BATCH_SIZE);
auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE * BATCH_SIZE);
auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE * BATCH_SIZE);
@@ -136,36 +110,6 @@ int main(int argc, char** argv)
CHK_IF_RETURN(ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
GpuOutputNew));
=======
auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE);
auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE);
auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE);
// gpu allocation
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE));
// init inputs
incremental_values(CpuScalars.get(), NTT_SIZE);
CHK_IF_RETURN(cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE, cudaMemcpyHostToDevice));
// inplace
if (INPLACE) {
CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
}
// run ntt
auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
// NEW
CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt)
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
GpuOutputNew);
>>>>>>> main
}
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
@@ -174,16 +118,10 @@ int main(int argc, char** argv)
// OLD
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
<<<<<<< HEAD
ntt_config.ntt_algorithm = ntt::NttAlgorithm::Radix2;
for (size_t i = 0; i < iterations; i++) {
CHK_IF_RETURN(
ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld));
=======
ntt_config.is_force_radix2 = true;
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld);
>>>>>>> main
}
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
@@ -201,17 +139,12 @@ int main(int argc, char** argv)
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
int count = INPLACE ? 1 : 10;
if (INPLACE) {
<<<<<<< HEAD
CHK_IF_RETURN(
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
=======
CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
>>>>>>> main
}
CHK_IF_RETURN(benchmark(true /*=print*/, count));
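
Editor's note: both timing paths in the benchmark lambda follow the standard CUDA event pattern: record an event on the benchmark stream before and after the measured launches, synchronize, then read the elapsed milliseconds (the elapsed-time read itself is outside the visible hunks). A self-contained sketch; the loop body and stream are placeholders:

    #include <cuda_runtime.h>

    // Minimal CUDA event timing sketch; the work launched inside the loop is a placeholder.
    static float time_ms(cudaStream_t stream, int iterations)
    {
      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);
      cudaEventRecord(start, stream);
      for (int i = 0; i < iterations; ++i) {
        // launch the measured work on `stream` here, e.g. ntt::NTT(...)
      }
      cudaEventRecord(stop, stream);
      cudaEventSynchronize(stop);              // wait until the stream reaches `stop`
      float ms = 0.f;
      cudaEventElapsedTime(&ms, start, stop);
      cudaEventDestroy(start);
      cudaEventDestroy(stop);
      return ms;
    }
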
// verify
<<<<<<< HEAD
CHK_IF_RETURN(
cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
CHK_IF_RETURN(
@@ -219,13 +152,6 @@ int main(int argc, char** argv)
bool success = true;
for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) {
=======
CHK_IF_RETURN(cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
CHK_IF_RETURN(cudaMemcpy(CpuOutputOld.get(), GpuOutputOld, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
bool success = true;
for (int i = 0; i < NTT_SIZE; i++) {
>>>>>>> main
if (CpuOutputNew[i] != CpuOutputOld[i]) {
success = false;
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;