feature: mixed-radix NTT fast twiddles mode (#382)

- this mode allocates an additional 4N twiddle factors (N being the max NTT size) to achieve faster computation
- enabled by a flag on InitDomain(). Defaults to false.
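
For reference, enabling the mode is a one-argument change at domain initialization. A minimal sketch, reusing the test_scalar/omega helpers and ntt_config from the benchmark diffs below:

    const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
    // pass true to opt into fast-twiddles mode; the argument defaults to false
    ntt::InitDomain(basic_root, ntt_config.ctx, true /*=fast_twiddles_mode*/);

With a 32-byte scalar field the extra 4N twiddles cost 128*N bytes, e.g. roughly 128 MiB for a maximum NTT size of 2^20.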

Co-authored-by: hadaringonyama <hadar@ingonyama.com>
yshekel
2024-02-22 00:02:02 +02:00
committed by GitHub
parent 4b221e9665
commit 275b2f4958
9 changed files with 455 additions and 109 deletions

View File

@@ -56,7 +56,7 @@ int main(int argc, char** argv)
CHK_IF_RETURN(cudaEventCreate(&stop));
const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
ntt::InitDomain(basic_root, ntt_config.ctx);
ntt::InitDomain(basic_root, ntt_config.ctx, true /*=fast_twiddles_mode*/);
// (1) cpu allocation
auto CpuA = std::make_unique<test_data[]>(NTT_SIZE);

View File

@@ -6,12 +6,12 @@
namespace ntt {
static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit, bool fast_tw)
{
uint32_t rev_num = 0, temp, dig_len;
if (dit) {
for (int i = 4; i >= 0; i--) {
dig_len = STAGE_SIZES_DEVICE[log_size][i];
dig_len = fast_tw ? STAGE_SIZES_DEVICE_FT[log_size][i] : STAGE_SIZES_DEVICE[log_size][i];
temp = num & ((1 << dig_len) - 1);
num = num >> dig_len;
rev_num = rev_num << dig_len;
@@ -19,7 +19,7 @@ namespace ntt {
}
} else {
for (int i = 0; i < 5; i++) {
dig_len = STAGE_SIZES_DEVICE[log_size][i];
dig_len = fast_tw ? STAGE_SIZES_DEVICE_FT[log_size][i] : STAGE_SIZES_DEVICE[log_size][i];
temp = num & ((1 << dig_len) - 1);
num = num >> dig_len;
rev_num = rev_num << dig_len;
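
For intuition, dig_rev reverses the order of the variable-width digits of an index, with the per-stage digit widths taken from STAGE_SIZES_DEVICE (or the fast-twiddles table when fast_tw is set). A host-side analogue of the same idea, with the widths passed in explicitly (illustrative only, not part of the kernel):

    // Reverse the order of the mixed-radix digits of `num`, given up to 5
    // digit widths in bits (zero widths are no-ops, matching the tables).
    uint32_t host_digit_reverse(uint32_t num, const uint32_t widths[5])
    {
      uint32_t rev = 0;
      for (int i = 0; i < 5; i++) {
        rev = (rev << widths[i]) | (num & ((1u << widths[i]) - 1)); // append lowest digit
        num >>= widths[i];
      }
      return rev;
    }

    // e.g. widths {4, 5, 4, 0, 0} (log_size 13 in the generic tables): an index whose
    // digits are (c, b, a) from high to low bits comes out with digits (a, b, c).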
@@ -33,18 +33,18 @@ namespace ntt {
enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural };
static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, eRevType rev_type)
static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type)
{
switch (rev_type) {
case eRevType::RevToMixedRev:
// R -> N -> MR
return dig_rev(bit_rev(num, log_size), log_size, dit);
return dig_rev(bit_rev(num, log_size), log_size, dit, fast_tw);
case eRevType::MixedRevToRev:
// MR -> N -> R
return bit_rev(dig_rev(num, log_size, dit), log_size);
return bit_rev(dig_rev(num, log_size, dit, fast_tw), log_size);
case eRevType::NaturalToMixedRev:
case eRevType::MixedRevToNatural:
return dig_rev(num, log_size, dit);
return dig_rev(num, log_size, dit, fast_tw);
case eRevType::NaturalToRev:
return bit_rev(num, log_size);
default:
@@ -56,7 +56,7 @@ namespace ntt {
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void reorder_digits_inplace_and_normalize_kernel(
E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -74,7 +74,7 @@ namespace ntt {
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = generalized_rev(next_element, log_size, dit, rev_type);
next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element + size * batch_idx;
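
The in-place reorder kernel uses a cycle-leader scheme: a thread only handles the permutation cycle on which its own index is the smallest member (the early return above), so every element is moved exactly once. A host-side sketch of the same idea, with perm standing in for generalized_rev (the direction of the move, gather vs. scatter, depends on rev_type, so treat it as illustrative):

    #include <functional>
    #include <utility>
    #include <vector>

    void permute_inplace(std::vector<int>& a, const std::function<int(int)>& perm)
    {
      const int n = (int)a.size();
      for (int start = 0; start < n; start++) {
        // skip unless `start` is the smallest index on its cycle
        int cur = perm(start);
        bool leader = true;
        while (cur != start) {
          if (cur < start) { leader = false; break; }
          cur = perm(cur);
        }
        if (!leader) continue;
        // rotate the cycle so that position perm(i) receives the old a[i]
        int val = a[start];
        cur = perm(start);
        while (cur != start) {
          std::swap(val, a[cur]);
          cur = perm(cur);
        }
        a[start] = val;
      }
    }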
@@ -91,12 +91,19 @@ namespace ntt {
template <typename E, typename S>
__launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
E* arr, E* arr_reordered, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
E* arr,
E* arr_reordered,
uint32_t log_size,
bool dit,
bool fast_tw,
eRevType rev_type,
bool is_normalize,
S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t rd = tid;
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type);
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
@@ -116,7 +123,7 @@ namespace ntt {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= n_elements * batch_size) return;
int64_t scalar_id = tid % n_elements;
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, rev_type);
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type);
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
@@ -136,7 +143,8 @@ namespace ntt {
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
bool dit,
bool fast_tw)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
@@ -150,14 +158,23 @@ namespace ntt {
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
engine.loadBasicTwiddles(basic_twiddles, inv);
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride && dit) {
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
if (fast_tw)
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
else
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.loadInternalTwiddles64(internal_twiddles, strided, inv);
if (fast_tw)
engine.loadInternalTwiddles64(internal_twiddles, strided);
else
engine.loadInternalTwiddlesGeneric64(internal_twiddles, strided, inv);
#pragma unroll 1
for (uint32_t phase = 0; phase < 2; phase++) {
@@ -171,8 +188,11 @@ namespace ntt {
}
if (twiddle_stride && !dit) {
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
if (fast_tw)
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
else
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
@@ -194,7 +214,8 @@ namespace ntt {
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
bool dit,
bool fast_tw)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
@@ -209,9 +230,15 @@ namespace ntt {
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
engine.loadBasicTwiddles(basic_twiddles, inv);
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
if (fast_tw)
engine.loadInternalTwiddles32(internal_twiddles, strided);
else
engine.loadInternalTwiddlesGeneric32(internal_twiddles, strided, inv);
engine.ntt8win();
engine.twiddlesInternal();
engine.SharedData32Columns8(shmem, true, false, strided); // store
@@ -219,8 +246,11 @@ namespace ntt {
engine.SharedData32Rows4_2(shmem, false, false, strided); // load
engine.ntt4_2();
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
if (fast_tw)
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
else
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
@@ -242,7 +272,8 @@ namespace ntt {
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
bool dit,
bool fast_tw)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
@@ -257,14 +288,23 @@ namespace ntt {
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
engine.loadBasicTwiddles(basic_twiddles, inv);
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
if (fast_tw)
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
else
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
if (fast_tw)
engine.loadInternalTwiddles32(internal_twiddles, strided);
else
engine.loadInternalTwiddlesGeneric32(internal_twiddles, strided, inv);
engine.ntt4_2();
engine.SharedData32Columns4_2(shmem, true, false, strided); // store
__syncthreads();
@@ -290,7 +330,8 @@ namespace ntt {
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
bool dit,
bool fast_tw)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
@@ -305,9 +346,15 @@ namespace ntt {
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
engine.loadBasicTwiddles(basic_twiddles, inv);
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
if (fast_tw)
engine.loadInternalTwiddles16(internal_twiddles, strided);
else
engine.loadInternalTwiddlesGeneric16(internal_twiddles, strided, inv);
engine.ntt8win();
engine.twiddlesInternal();
engine.SharedData16Columns8(shmem, true, false, strided); // store
@@ -315,8 +362,11 @@ namespace ntt {
engine.SharedData16Rows2_4(shmem, false, false, strided); // load
engine.ntt2_4();
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
if (fast_tw)
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
else
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
@@ -338,7 +388,8 @@ namespace ntt {
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
bool dit,
bool fast_tw)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
@@ -353,14 +404,23 @@ namespace ntt {
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
engine.loadBasicTwiddles(basic_twiddles, inv);
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
if (fast_tw)
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
else
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
if (fast_tw)
engine.loadInternalTwiddles16(internal_twiddles, strided);
else
engine.loadInternalTwiddlesGeneric16(internal_twiddles, strided, inv);
engine.ntt2_4();
engine.SharedData16Columns2_4(shmem, true, false, strided); // store
__syncthreads();
@@ -388,8 +448,9 @@ namespace ntt {
}
}
// Generic twiddles: 1N twiddles for forward and inverse NTT
template <typename S>
__global__ void generate_basic_twiddles(S basic_root, S* w6_table, S* basic_twiddles)
__global__ void generate_basic_twiddles_generic(S basic_root, S* w6_table, S* basic_twiddles)
{
S w0 = basic_root * basic_root;
S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1);
@@ -484,7 +545,7 @@ namespace ntt {
if (log_size > 2)
for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++)
temp_root = temp_root * temp_root;
generate_basic_twiddles<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles);
generate_basic_twiddles_generic<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles);
const int NOF_BLOCKS = (log_size >= 8) ? (1 << (log_size - 8)) : 1;
const int NOF_THREADS = (log_size >= 8) ? 256 : (1 << log_size);
@@ -501,6 +562,100 @@ namespace ntt {
return CHK_LAST();
}
// Fast-twiddles: 2N twiddles for forward, 2N for inverse
template <typename S>
__global__ void generate_basic_twiddles_fast_twiddles_mode(S basic_root, S* basic_twiddles)
{
S w0 = basic_root * basic_root;
S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1);
S w2 = (basic_root - w0 * basic_root) * S::inv_log_size(1);
basic_twiddles[0] = w0;
basic_twiddles[1] = w1;
basic_twiddles[2] = w2;
}
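
Reading the kernel above with ω = basic_root and 2⁻¹ = S::inv_log_size(1), the stored values are w0 = ω², w1 = (ω + ω³)·2⁻¹ and w2 = (ω − ω³)·2⁻¹ — the same three opening values computed by generate_basic_twiddles_generic.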
template <typename S>
__global__ void generate_twiddle_combinations_fast_twiddles_mode(
S* w6_table,
S* w12_table,
S* w18_table,
S* w24_table,
S* w30_table,
S* external_twiddles,
uint32_t log_size,
uint32_t prev_log_size)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
uint32_t exp = ((tid & ((1 << prev_log_size) - 1)) * (tid >> prev_log_size)) << (30 - log_size);
S w6, w12, w18, w24, w30;
w6 = w6_table[exp >> 24];
w12 = w12_table[((exp >> 18) & 0x3f)];
w18 = w18_table[((exp >> 12) & 0x3f)];
w24 = w24_table[((exp >> 6) & 0x3f)];
w30 = w30_table[(exp & 0x3f)];
S t = w6 * w12 * w18 * w24 * w30;
external_twiddles[tid + (1 << log_size) - 1] = t;
}
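
The combination kernel assembles each external twiddle as a product of five 64-entry table lookups, one per 6-bit digit of the 30-bit exponent exp. A host-side restatement of that identity (it assumes the tables hold the root raised to each digit value in its position, i.e. tbl6[k] = w^(k·2^24), ..., tbl30[k] = w^k for some root w; the field type S and table names are placeholders):

    template <typename S>
    S twiddle_from_digit_tables(const S* tbl6, const S* tbl12, const S* tbl18,
                                const S* tbl24, const S* tbl30, uint32_t exp)
    {
      // base-2^6 digits of exp, highest first; under the assumption above the product equals w^exp
      return tbl6[exp >> 24] * tbl12[(exp >> 18) & 0x3f] * tbl18[(exp >> 12) & 0x3f] *
             tbl24[(exp >> 6) & 0x3f] * tbl30[exp & 0x3f];
    }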
template <typename S>
cudaError_t generate_external_twiddles_fast_twiddles_mode(
const S& basic_root,
S* external_twiddles,
S*& internal_twiddles,
S*& basic_twiddles,
uint32_t log_size,
cudaStream_t& stream)
{
CHK_INIT_IF_RETURN();
S* w6_table;
S* w12_table;
S* w18_table;
S* w24_table;
S* w30_table;
CHK_IF_RETURN(cudaMallocAsync(&w6_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w12_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w18_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w24_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w30_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&basic_twiddles, 3 * sizeof(S), stream));
S temp_root = basic_root;
generate_base_table<<<1, 1, 0, stream>>>(basic_root, w30_table, 1 << (30 - log_size));
if (log_size > 24)
for (int i = 0; i < 6 - (30 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w24_table, 1 << (log_size > 24 ? 0 : 24 - log_size));
if (log_size > 18)
for (int i = 0; i < 6 - (log_size > 24 ? 0 : 24 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w18_table, 1 << (log_size > 18 ? 0 : 18 - log_size));
if (log_size > 12)
for (int i = 0; i < 6 - (log_size > 18 ? 0 : 18 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w12_table, 1 << (log_size > 12 ? 0 : 12 - log_size));
if (log_size > 6)
for (int i = 0; i < 6 - (log_size > 12 ? 0 : 12 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w6_table, 1 << (log_size > 6 ? 0 : 6 - log_size));
for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++)
temp_root = temp_root * temp_root;
generate_basic_twiddles_fast_twiddles_mode<<<1, 1, 0, stream>>>(temp_root, basic_twiddles);
for (int i = 8; i < log_size + 1; i++) {
generate_twiddle_combinations_fast_twiddles_mode<<<1 << (i - 8), 256, 0, stream>>>(
w6_table, w12_table, w18_table, w24_table, w30_table, external_twiddles, i, STAGE_PREV_SIZES[i]);
}
internal_twiddles = w6_table;
CHK_IF_RETURN(cudaFreeAsync(w12_table, stream));
CHK_IF_RETURN(cudaFreeAsync(w18_table, stream));
CHK_IF_RETURN(cudaFreeAsync(w24_table, stream));
CHK_IF_RETURN(cudaFreeAsync(w30_table, stream));
return CHK_LAST();
}
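
For a sense of the memory layout this function produces: each loop iteration writes a block of 2^i twiddles at offset 2^i − 1, so one direction occupies just under 2N elements of the external-twiddles buffer (InitDomain allocates max_size * 2 per direction, see the InitDomain changes further down), and keeping both the forward and inverse sets is where the 4N figure comes from. A small host-side accounting sketch (illustrative):

    // Entries touched by one direction of the fast external twiddles:
    // blocks of 2^i entries at offsets 2^i - 1 for i = 8..log_size are disjoint
    // and the last one ends at 2^(log_size + 1) - 1, i.e. just under 2N.
    size_t fast_external_twiddle_entries(uint32_t log_size)
    {
      size_t end = 0;
      for (uint32_t i = 8; i <= log_size; i++)
        end = ((size_t{1} << i) - 1) + (size_t{1} << i); // block offset + block size
      return end;
    }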
template <typename E, typename S>
cudaError_t large_ntt(
E* in,
@@ -514,6 +669,7 @@ namespace ntt {
bool inv,
bool normalize,
bool dit,
bool fast_tw,
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();
@@ -529,11 +685,11 @@ namespace ntt {
if (dit) {
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit);
false, 0, inv, dit, fast_tw);
} else { // dif
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit);
false, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
return CHK_LAST();
@@ -545,11 +701,11 @@ namespace ntt {
if (dit) {
ntt32dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit);
false, 0, inv, dit, fast_tw);
} else { // dif
ntt32<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit);
false, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
return CHK_LAST();
@@ -560,7 +716,7 @@ namespace ntt {
const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit);
false, 0, inv, dit, fast_tw);
if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
return CHK_LAST();
}
@@ -571,17 +727,17 @@ namespace ntt {
if (dit) {
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit);
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit);
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
} else { // dif
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit);
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit);
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
return CHK_LAST();
@@ -591,43 +747,49 @@ namespace ntt {
uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
uint32_t stride_log = 0;
for (int j = 0; j < i; j++)
stride_log += STAGE_SIZES_HOST[log_size][j];
stride_log += fast_tw ? STAGE_SIZES_HOST_FT[log_size][j] : STAGE_SIZES_HOST[log_size][j];
if (stage_size == 6)
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
else if (stage_size == 5)
ntt32dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
else if (stage_size == 4)
ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
}
} else { // dif
bool first_run = false, prev_stage = false;
for (int i = 4; i >= 0; i--) {
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
uint32_t stride_log = 0;
for (int j = 0; j < i; j++)
stride_log += STAGE_SIZES_HOST[log_size][j];
stride_log += fast_tw ? STAGE_SIZES_HOST_FT[log_size][j] : STAGE_SIZES_HOST[log_size][j];
first_run = stage_size && !prev_stage;
if (stage_size == 6)
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
else if (stage_size == 5)
ntt32<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
else if (stage_size == 4)
ntt16<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
prev_stage = stage_size;
}
}
@@ -648,6 +810,7 @@ namespace ntt {
int max_logn,
int batch_size,
bool is_inverse,
bool fast_tw,
Ordering ordering,
S* arbitrary_coset,
int coset_gen_index,
@@ -706,10 +869,10 @@ namespace ntt {
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
d_input = d_output;
@@ -718,11 +881,11 @@ namespace ntt {
// inplace ntt
CHK_IF_RETURN(large_ntt(
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
(is_normalize && reverse_output == eRevType::None), dit, cuda_stream));
(is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
if (reverse_output != eRevType::None) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, reverse_output, is_normalize, S::inv_log_size(logn));
d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
}
if (is_on_coset && is_inverse) {
@@ -743,6 +906,14 @@ namespace ntt {
uint32_t log_size,
cudaStream_t& stream);
template cudaError_t generate_external_twiddles_fast_twiddles_mode(
const curve_config::scalar_t& basic_root,
curve_config::scalar_t* external_twiddles,
curve_config::scalar_t*& internal_twiddles,
curve_config::scalar_t*& basic_twiddles,
uint32_t log_size,
cudaStream_t& stream);
template cudaError_t mixed_radix_ntt<curve_config::scalar_t, curve_config::scalar_t>(
curve_config::scalar_t* d_input,
curve_config::scalar_t* d_output,
@@ -753,6 +924,7 @@ namespace ntt {
int max_logn,
int batch_size,
bool is_inverse,
bool fast_tw,
Ordering ordering,
curve_config::scalar_t* arbitrary_coset,
int coset_gen_index,

View File

@@ -370,14 +370,23 @@ namespace ntt {
int max_size = 0;
int max_log_size = 0;
S* twiddles = nullptr;
bool initialized = false; // protection for multi-threaded case
std::unordered_map<S, int> coset_index = {};
S* internal_twiddles = nullptr; // required by mixed-radix NTT
S* basic_twiddles = nullptr; // required by mixed-radix NTT
// mixed-radix NTT supports a fast-twiddles option at the cost of an additional 4N memory (where N is the max NTT size)
S* fast_external_twiddles = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
S* fast_internal_twiddles = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
S* fast_basic_twiddles = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
S* fast_external_twiddles_inv = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
S* fast_internal_twiddles_inv = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
S* fast_basic_twiddles_inv = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
public:
template <typename U>
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx, bool fast_tw);
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
@@ -389,7 +398,7 @@ namespace ntt {
static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};
template <typename S>
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode)
{
CHK_INIT_IF_RETURN();
@@ -399,11 +408,11 @@ namespace ntt {
// please note that this offers just basic thread-safety,
// it's assumed a singleton (non-enforced) that is supposed
// to be initialized once per device per program lifetime
if (!domain.twiddles) {
if (!domain.initialized) {
// Mutex is automatically released when lock goes out of scope, even in case of exceptions
std::lock_guard<std::mutex> lock(Domain<S>::device_domain_mutex);
// double check locking
if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain
if (domain.initialized) return CHK_LAST(); // another thread is already initializing the domain
bool found_logn = false;
S omega = primitive_root;
@@ -430,6 +439,25 @@ namespace ntt {
CHK_IF_RETURN(generate_external_twiddles_generic(
primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
ctx.stream));
if (fast_twiddles_mode) {
// generating fast-twiddles (note that this costs 4N additional memory)
CHK_IF_RETURN(cudaMallocAsync(&domain.fast_external_twiddles, domain.max_size * sizeof(S) * 2, ctx.stream));
CHK_IF_RETURN(cudaMallocAsync(&domain.fast_external_twiddles_inv, domain.max_size * sizeof(S) * 2, ctx.stream));
// fast-twiddles forward NTT
CHK_IF_RETURN(generate_external_twiddles_fast_twiddles_mode(
primitive_root, domain.fast_external_twiddles, domain.fast_internal_twiddles, domain.fast_basic_twiddles,
domain.max_log_size, ctx.stream));
// fast-twiddles inverse NTT
S primitive_root_inv;
CHK_IF_RETURN(cudaMemcpyAsync(
&primitive_root_inv, &domain.twiddles[domain.max_size - 1], sizeof(S), cudaMemcpyDeviceToHost, ctx.stream));
CHK_IF_RETURN(generate_external_twiddles_fast_twiddles_mode(
primitive_root_inv, domain.fast_external_twiddles_inv, domain.fast_internal_twiddles_inv,
domain.fast_basic_twiddles_inv, domain.max_log_size, ctx.stream));
}
CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
const bool is_map_only_powers_of_primitive_root = true;
@@ -447,6 +475,7 @@ namespace ntt {
domain.coset_index[domain.twiddles[i]] = i;
}
}
domain.initialized = true;
}
return CHK_LAST();
@@ -467,6 +496,19 @@ namespace ntt {
basic_twiddles = nullptr;
coset_index.clear();
cudaFreeAsync(fast_external_twiddles, ctx.stream);
fast_external_twiddles = nullptr;
cudaFreeAsync(fast_internal_twiddles, ctx.stream);
fast_internal_twiddles = nullptr;
cudaFreeAsync(fast_basic_twiddles, ctx.stream);
fast_basic_twiddles = nullptr;
cudaFreeAsync(fast_external_twiddles_inv, ctx.stream);
fast_external_twiddles_inv = nullptr;
cudaFreeAsync(fast_internal_twiddles_inv, ctx.stream);
fast_internal_twiddles_inv = nullptr;
cudaFreeAsync(fast_basic_twiddles_inv, ctx.stream);
fast_basic_twiddles_inv = nullptr;
return CHK_LAST();
}
@@ -607,9 +649,21 @@ namespace ntt {
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
coset_index, stream));
} else {
const bool is_on_coset = (coset_index != 0) || coset;
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
S* twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
: domain.twiddles;
S* internal_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
: domain.internal_twiddles;
S* basic_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
: domain.basic_twiddles;
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, size, domain.max_log_size,
batch_size, is_inverse, config.ordering, coset, coset_index, stream));
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
}
if (!are_outputs_on_device)
@@ -645,10 +699,10 @@ namespace ntt {
* value of template parameter (where the curve is given by `-DCURVE` env variable during build):
* - `S` is the [scalar field](@ref scalar_t) of the curve;
*/
extern "C" cudaError_t
CONCAT_EXPAND(CURVE, InitializeDomain)(curve_config::scalar_t primitive_root, device_context::DeviceContext& ctx)
extern "C" cudaError_t CONCAT_EXPAND(CURVE, InitializeDomain)(
curve_config::scalar_t primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode)
{
return InitDomain(primitive_root, ctx);
return InitDomain(primitive_root, ctx, fast_twiddles_mode);
}
/**

View File

@@ -32,10 +32,13 @@ namespace ntt {
* @param primitive_root Primitive root in field `S` of order \f$ 2^s \f$. This should be the smallest power-of-2
* order that's large enough to support any NTT you might want to perform.
* @param ctx Details related to the device such as its id and stream id.
* @param fast_twiddles_mode A mode where more memory is allocated for twiddle factors in exchange for faster compute.
* In this mode an additional 4N memory is needed, where N is the largest NTT size to be supported (which is derived from the
* primitive_root).
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
template <typename S>
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx);
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode = false);
/**
* @enum NTTDir

View File

@@ -16,6 +16,15 @@ namespace ntt {
uint32_t log_size,
cudaStream_t& stream);
template <typename S>
cudaError_t generate_external_twiddles_fast_twiddles_mode(
const S& basic_root,
S* external_twiddles,
S*& internal_twiddles,
S*& basic_twiddles,
uint32_t log_size,
cudaStream_t& stream);
template <typename E, typename S>
cudaError_t mixed_radix_ntt(
E* d_input,
@@ -27,6 +36,7 @@ namespace ntt {
int max_logn,
int batch_size,
bool is_inverse,
bool fast_tw,
Ordering ordering,
S* arbitrary_coset,
int coset_gen_index,

View File

@@ -34,13 +34,14 @@ int main(int argc, char** argv)
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
float icicle_time, new_time;
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming second input is the log-size
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19;
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : true;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true;
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
@@ -51,8 +52,8 @@ int main(int argc, char** argv)
: "MN";
printf(
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s\n", NTT_LOG_SIZE, INPLACE, INV,
BATCH_SIZE, COSET_IDX, ordering_str);
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE,
INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW);
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
@@ -70,7 +71,7 @@ int main(int argc, char** argv)
auto start = std::chrono::high_resolution_clock::now();
const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
ntt::InitDomain(basic_root, ntt_config.ctx);
ntt::InitDomain(basic_root, ntt_config.ctx, FAST_TW);
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;

View File

@@ -13,19 +13,33 @@ struct stage_metadata {
uint32_t ntt_inp_id;
};
uint32_t constexpr STAGE_SIZES_HOST[31][5] = {
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
{0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
{4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
{6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
{6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
#define STAGE_SIZES_DATA \
{ \
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, \
{6, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, \
{6, 6, 0, 0, 0}, {4, 5, 4, 0, 0}, {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, \
{6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0}, {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, \
{6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6}, {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, \
{6, 6, 6, 6, 6}, \
}
uint32_t constexpr STAGE_SIZES_HOST[31][5] = STAGE_SIZES_DATA;
__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = STAGE_SIZES_DATA;
__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = {
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
{0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
{4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
{6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
{6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
// construction for fast-twiddles
uint32_t constexpr STAGE_PREV_SIZES[31] = {0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 5, 6, 6, 9, 9, 10,
11, 11, 12, 15, 15, 16, 16, 18, 18, 20, 21, 21, 22, 23, 24};
#define STAGE_SIZES_DATA_FAST_TW \
{ \
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, \
{6, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, \
{6, 6, 0, 0, 0}, {5, 4, 4, 0, 0}, {5, 4, 5, 0, 0}, {5, 5, 5, 0, 0}, {6, 5, 5, 0, 0}, {6, 5, 6, 0, 0}, \
{6, 6, 6, 0, 0}, {5, 5, 5, 4, 0}, {5, 5, 5, 5, 0}, {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, \
{6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 5, 5, 5}, {6, 5, 5, 5, 6}, {6, 5, 5, 6, 6}, {6, 6, 6, 5, 6}, \
{6, 6, 6, 6, 6}, \
}
uint32_t constexpr STAGE_SIZES_HOST_FT[31][5] = STAGE_SIZES_DATA_FAST_TW;
__device__ uint32_t constexpr STAGE_SIZES_DEVICE_FT[31][5] = STAGE_SIZES_DATA_FAST_TW;
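
STAGE_PREV_SIZES appears to hold, for each log size, the sum of all fast-twiddles stage widths except the last one, i.e. the log-stride at which the final stage runs (e.g. 19 → 5 + 5 + 5 = 15). A small host-side consistency check of that reading (illustrative):

    bool check_stage_prev_sizes()
    {
      for (int i = 0; i <= 30; i++) {
        uint32_t sum = 0, last = 0;
        for (int j = 0; j < 5; j++) {
          uint32_t s = STAGE_SIZES_HOST_FT[i][j];
          if (s == 0) break;
          sum += last; // accumulate every stage width except the most recent one
          last = s;
        }
        if (sum != STAGE_PREV_SIZES[i]) return false;
      }
      return true;
    }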
template <typename E, typename S>
class NTTEngine
@@ -36,7 +50,15 @@ public:
S WI[7];
S WE[8];
__device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles, bool inv)
__device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles)
{
#pragma unroll
for (int i = 0; i < 3; i++) {
WB[i] = basic_twiddles[i];
}
}
__device__ __forceinline__ void loadBasicTwiddlesGeneric(S* basic_twiddles, bool inv)
{
#pragma unroll
for (int i = 0; i < 3; i++) {
@@ -44,7 +66,31 @@ public:
}
}
__device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride, bool inv)
__device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
WI[i] = data[((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1)];
}
}
__device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
WI[i] = data[2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1)];
}
}
__device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
WI[i] = data[4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1)];
}
}
__device__ __forceinline__ void loadInternalTwiddlesGeneric64(S* data, bool stride, bool inv)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
@@ -53,7 +99,7 @@ public:
}
}
__device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride, bool inv)
__device__ __forceinline__ void loadInternalTwiddlesGeneric32(S* data, bool stride, bool inv)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
@@ -62,7 +108,7 @@ public:
}
}
__device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride, bool inv)
__device__ __forceinline__ void loadInternalTwiddlesGeneric16(S* data, bool stride, bool inv)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
@@ -71,8 +117,47 @@ public:
}
}
__device__ __forceinline__ void
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
WE[i] = data[8 * i * tw_order + (1 << tw_log_order + 6) - 1];
}
}
__device__ __forceinline__ void
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
WE[4 * j + i] = data[(8 * i + j) * tw_order + (1 << tw_log_order + 5) - 1];
}
}
}
__device__ __forceinline__ void
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
WE[2 * j + i] = data[(8 * i + j) * tw_order + (1 << tw_log_order + 4) - 1];
}
}
}
__device__ __forceinline__ void loadExternalTwiddlesGeneric64(
E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
{
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
@@ -83,7 +168,7 @@ public:
}
__device__ __forceinline__ void loadExternalTwiddlesGeneric32(
E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
@@ -97,7 +182,7 @@ public:
}
__device__ __forceinline__ void loadExternalTwiddlesGeneric16(
E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
{
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {

View File

@@ -119,6 +119,7 @@ pub trait NTT<F: FieldImpl> {
output: &mut HostOrDeviceSlice<F>,
) -> IcicleResult<()>;
fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
fn initialize_domain_fast_twiddles_mode(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
}
/// Computes the NTT, or a batch of several NTTs.
@@ -172,6 +173,13 @@ where
{
<<F as FieldImpl>::Config as NTT<F>>::initialize_domain(primitive_root, ctx)
}
pub fn initialize_domain_fast_twiddles_mode<F>(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>
where
F: FieldImpl,
<F as FieldImpl>::Config: NTT<F>,
{
<<F as FieldImpl>::Config as NTT<F>>::initialize_domain_fast_twiddles_mode(primitive_root, ctx)
}
#[macro_export]
macro_rules! impl_ntt {
@@ -195,7 +203,11 @@ macro_rules! impl_ntt {
) -> CudaError;
#[link_name = concat!($field_prefix, "InitializeDomain")]
pub(crate) fn initialize_ntt_domain(primitive_root: $field, ctx: &DeviceContext) -> CudaError;
pub(crate) fn initialize_ntt_domain(
primitive_root: $field,
ctx: &DeviceContext,
fast_twiddles_mode: bool,
) -> CudaError;
}
}
@@ -219,7 +231,10 @@ macro_rules! impl_ntt {
}
fn initialize_domain(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx).wrap() }
unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx, false).wrap() }
}
fn initialize_domain_fast_twiddles_mode(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx, true).wrap() }
}
}
};
@@ -232,28 +247,29 @@ macro_rules! impl_ntt_tests {
) => {
const MAX_SIZE: u64 = 1 << 17;
static INIT: OnceLock<()> = OnceLock::new();
const FAST_TWIDDLES_MODE: bool = false;
#[test]
fn test_ntt() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
check_ntt::<$field>()
}
#[test]
fn test_ntt_coset_from_subgroup() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
check_ntt_coset_from_subgroup::<$field>()
}
#[test]
fn test_ntt_arbitrary_coset() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
check_ntt_arbitrary_coset::<$field>()
}
#[test]
fn test_ntt_batch() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
check_ntt_batch::<$field>()
}

View File

@@ -9,21 +9,25 @@ use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use crate::{
ntt::{initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
ntt::{initialize_domain, initialize_domain_fast_twiddles_mode, ntt, NTTDir, NttAlgorithm, Ordering},
traits::{ArkConvertible, FieldImpl, GenerateRandom},
};
use super::NTTConfig;
use super::NTT;
pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize)
pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize, fast_twiddles_mode: bool)
where
F::ArkEquivalent: FftField,
<F as FieldImpl>::Config: NTT<F>,
{
let ctx = DeviceContext::default_for_device(device_id);
let ark_rou = F::ArkEquivalent::get_root_of_unity(max_size).unwrap();
initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
if fast_twiddles_mode {
initialize_domain_fast_twiddles_mode(F::from_ark(ark_rou), &ctx).unwrap();
} else {
initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
}
}
pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
@@ -289,7 +293,8 @@ where
.into_par_iter()
.for_each(move |device_id| {
set_device(device_id).unwrap();
init_domain::<F>(1 << 16, device_id); // init domain per device
// if there is more than one device, use fast-twiddles mode (note that the domain is reused per device if not released)
init_domain::<F>(1 << 16, device_id, true /*=fast twiddles mode*/); // init domain per device
let test_sizes = [1 << 4, 1 << 12];
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {