feature: mixed-radix NTT fast twiddles mode (#382)
- This mode allocates an additional 4N twiddle factors to achieve faster computation. It is enabled by a flag on InitDomain() and defaults to false.

Co-authored-by: hadaringonyama <hadar@ingonyama.com>
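For orientation, a minimal host-side sketch of the new knob (a sketch only, assuming the test_scalar/ntt_config setup from the benchmark further down; a domain is initialized once per device, so pick one of the two calls):

    // Generic mode (default): 1N twiddles shared by forward and inverse NTT.
    ntt::InitDomain(basic_root, ntt_config.ctx);

    // Fast-twiddles mode: allocates 4N extra twiddles (2N forward + 2N inverse)
    // in exchange for cheaper twiddle loads inside the kernels.
    ntt::InitDomain(basic_root, ntt_config.ctx, true /*=fast_twiddles_mode*/);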
@@ -56,7 +56,7 @@ int main(int argc, char** argv)
   CHK_IF_RETURN(cudaEventCreate(&stop));
 
   const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
-  ntt::InitDomain(basic_root, ntt_config.ctx);
+  ntt::InitDomain(basic_root, ntt_config.ctx, true /*=fast_twiddles_mode*/);
 
   // (1) cpu allocation
   auto CpuA = std::make_unique<test_data[]>(NTT_SIZE);
@@ -6,12 +6,12 @@
 
 namespace ntt {
 
-  static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
+  static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit, bool fast_tw)
   {
     uint32_t rev_num = 0, temp, dig_len;
     if (dit) {
       for (int i = 4; i >= 0; i--) {
-        dig_len = STAGE_SIZES_DEVICE[log_size][i];
+        dig_len = fast_tw ? STAGE_SIZES_DEVICE_FT[log_size][i] : STAGE_SIZES_DEVICE[log_size][i];
         temp = num & ((1 << dig_len) - 1);
         num = num >> dig_len;
         rev_num = rev_num << dig_len;
@@ -19,7 +19,7 @@ namespace ntt {
       }
     } else {
       for (int i = 0; i < 5; i++) {
-        dig_len = STAGE_SIZES_DEVICE[log_size][i];
+        dig_len = fast_tw ? STAGE_SIZES_DEVICE_FT[log_size][i] : STAGE_SIZES_DEVICE[log_size][i];
         temp = num & ((1 << dig_len) - 1);
         num = num >> dig_len;
         rev_num = rev_num << dig_len;
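The digit reversal generalizes bit reversal to the mixed-radix digit widths of the chosen stage decomposition. A hedged host-side model of the same permutation (a sketch, not part of the commit; `stage_sizes` is one row of the STAGE_SIZES tables defined later in this diff):

    #include <cstdint>

    // Reverse the order of the mixed-radix digits of `num`, where the digit
    // widths are the per-stage log-sizes of the decomposition. Zero-width
    // digits (unused stages) are no-ops.
    uint32_t dig_rev_host(uint32_t num, const uint32_t stage_sizes[5])
    {
      uint32_t rev_num = 0;
      for (int i = 0; i < 5; i++) { // DIF digit order; the DIT path walks i = 4..0
        const uint32_t dig_len = stage_sizes[i];
        rev_num = (rev_num << dig_len) | (num & ((1u << dig_len) - 1));
        num >>= dig_len;
      }
      return rev_num;
    }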
@@ -33,18 +33,18 @@ namespace ntt {
 
   enum eRevType { None, RevToMixedRev, MixedRevToRev, NaturalToMixedRev, NaturalToRev, MixedRevToNatural };
 
-  static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, eRevType rev_type)
+  static __device__ uint32_t generalized_rev(uint32_t num, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type)
   {
     switch (rev_type) {
     case eRevType::RevToMixedRev:
       // R -> N -> MR
-      return dig_rev(bit_rev(num, log_size), log_size, dit);
+      return dig_rev(bit_rev(num, log_size), log_size, dit, fast_tw);
     case eRevType::MixedRevToRev:
       // MR -> N -> R
-      return bit_rev(dig_rev(num, log_size, dit), log_size);
+      return bit_rev(dig_rev(num, log_size, dit, fast_tw), log_size);
     case eRevType::NaturalToMixedRev:
     case eRevType::MixedRevToNatural:
-      return dig_rev(num, log_size, dit);
+      return dig_rev(num, log_size, dit, fast_tw);
     case eRevType::NaturalToRev:
       return bit_rev(num, log_size);
     default:
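Each enum case names a source and target ordering (N = natural, R = bit-reversed, MR = mixed/digit-reversed) and is realized as a composition through natural order. A hedged host-side illustration of the first case, reusing the dig_rev_host sketch above:

    // Host model (sketch): RevToMixedRev = (N -> MR) after (R -> N), i.e.
    // undo the bit reversal, then apply the mixed-radix digit reversal.
    uint32_t bit_rev_host(uint32_t num, uint32_t log_size)
    {
      uint32_t rev = 0;
      for (uint32_t i = 0; i < log_size; i++)
        rev |= ((num >> i) & 1u) << (log_size - 1 - i);
      return rev;
    }

    uint32_t rev_to_mixed_rev_host(uint32_t num, uint32_t log_size, const uint32_t stage_sizes[5])
    {
      return dig_rev_host(bit_rev_host(num, log_size), stage_sizes);
    }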
@@ -56,7 +56,7 @@ namespace ntt {
   // Note: the following reorder kernels are fused with normalization for INTT
   template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
   static __global__ void reorder_digits_inplace_and_normalize_kernel(
-    E* arr, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
+    E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N)
   {
     // launch N threads (per batch element)
     // each thread starts from one index and calculates the corresponding group
@@ -74,7 +74,7 @@ namespace ntt {
 
     uint32_t i = 1;
     for (; i < MAX_GROUP_SIZE;) {
-      next_element = generalized_rev(next_element, log_size, dit, rev_type);
+      next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type);
       if (next_element < idx) return; // not handling this group
       if (next_element == idx) break; // calculated whole group
       group[i++] = next_element + size * batch_idx;
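The in-place kernel reorders by following permutation cycles: each thread walks the cycle from its start index, and only the thread owning the cycle minimum performs the rotation, so every element moves exactly once. A hedged single-threaded host sketch of the same idea (perm plays the role of generalized_rev):

    #include <cstdint>
    #include <utility>

    // Sketch: in-place permutation by cycle following. Only the smallest index
    // of each cycle rotates it, mirroring the `next_element < idx` early-out.
    template <typename E, typename Perm>
    void permute_inplace(E* arr, uint32_t n, Perm perm)
    {
      for (uint32_t idx = 0; idx < n; idx++) {
        uint32_t next = perm(idx);
        while (next > idx)
          next = perm(next);      // keep walking the cycle
        if (next < idx) continue; // a smaller start index owns this cycle
        // next == idx: rotate the cycle starting at idx
        E carry = arr[idx];
        uint32_t cur = perm(idx);
        while (cur != idx) {
          std::swap(carry, arr[cur]); // place the carried element, pick up the next
          cur = perm(cur);
        }
        arr[idx] = carry;
      }
    }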
@@ -91,12 +91,19 @@ namespace ntt {
 
   template <typename E, typename S>
   __launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
-    E* arr, E* arr_reordered, uint32_t log_size, bool dit, eRevType rev_type, bool is_normalize, S inverse_N)
+    E* arr,
+    E* arr_reordered,
+    uint32_t log_size,
+    bool dit,
+    bool fast_tw,
+    eRevType rev_type,
+    bool is_normalize,
+    S inverse_N)
   {
     uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
     uint32_t rd = tid;
     uint32_t wr =
-      ((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, rev_type);
+      ((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
     arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
   }
 
@@ -116,7 +123,7 @@ namespace ntt {
     int tid = blockDim.x * blockIdx.x + threadIdx.x;
     if (tid >= n_elements * batch_size) return;
     int64_t scalar_id = tid % n_elements;
-    if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, rev_type);
+    if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type);
     out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
   }
 
@@ -136,7 +143,8 @@ namespace ntt {
     bool strided,
     uint32_t stage_num,
     bool inv,
-    bool dit)
+    bool dit,
+    bool fast_tw)
   {
     NTTEngine<E, S> engine;
     stage_metadata s_meta;
@@ -150,14 +158,23 @@ namespace ntt {
 
     if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
 
-    engine.loadBasicTwiddles(basic_twiddles, inv);
+    if (fast_tw)
+      engine.loadBasicTwiddles(basic_twiddles);
+    else
+      engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
     engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
     if (twiddle_stride && dit) {
-      engine.loadExternalTwiddlesGeneric64(
-        external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
+      if (fast_tw)
+        engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
+      else
+        engine.loadExternalTwiddlesGeneric64(
+          external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
       engine.twiddlesExternal();
     }
-    engine.loadInternalTwiddles64(internal_twiddles, strided, inv);
+    if (fast_tw)
+      engine.loadInternalTwiddles64(internal_twiddles, strided);
+    else
+      engine.loadInternalTwiddlesGeneric64(internal_twiddles, strided, inv);
 
 #pragma unroll 1
     for (uint32_t phase = 0; phase < 2; phase++) {
@@ -171,8 +188,11 @@ namespace ntt {
     }
 
     if (twiddle_stride && !dit) {
-      engine.loadExternalTwiddlesGeneric64(
-        external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
+      if (fast_tw)
+        engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
+      else
+        engine.loadExternalTwiddlesGeneric64(
+          external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
       engine.twiddlesExternal();
     }
     engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
@@ -194,7 +214,8 @@ namespace ntt {
     bool strided,
     uint32_t stage_num,
     bool inv,
-    bool dit)
+    bool dit,
+    bool fast_tw)
   {
     NTTEngine<E, S> engine;
     stage_metadata s_meta;
@@ -209,9 +230,15 @@ namespace ntt {
 
     if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
 
-    engine.loadBasicTwiddles(basic_twiddles, inv);
+    if (fast_tw)
+      engine.loadBasicTwiddles(basic_twiddles);
+    else
+      engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
     engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
-    engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
+    if (fast_tw)
+      engine.loadInternalTwiddles32(internal_twiddles, strided);
+    else
+      engine.loadInternalTwiddlesGeneric32(internal_twiddles, strided, inv);
     engine.ntt8win();
     engine.twiddlesInternal();
     engine.SharedData32Columns8(shmem, true, false, strided); // store
@@ -219,8 +246,11 @@ namespace ntt {
     engine.SharedData32Rows4_2(shmem, false, false, strided); // load
     engine.ntt4_2();
     if (twiddle_stride) {
-      engine.loadExternalTwiddlesGeneric32(
-        external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
+      if (fast_tw)
+        engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
+      else
+        engine.loadExternalTwiddlesGeneric32(
+          external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
       engine.twiddlesExternal();
     }
     engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
@@ -242,7 +272,8 @@ namespace ntt {
     bool strided,
     uint32_t stage_num,
     bool inv,
-    bool dit)
+    bool dit,
+    bool fast_tw)
   {
     NTTEngine<E, S> engine;
     stage_metadata s_meta;
@@ -257,14 +288,23 @@ namespace ntt {
 
     if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
 
-    engine.loadBasicTwiddles(basic_twiddles, inv);
+    if (fast_tw)
+      engine.loadBasicTwiddles(basic_twiddles);
+    else
+      engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
     engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
     if (twiddle_stride) {
-      engine.loadExternalTwiddlesGeneric32(
-        external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
+      if (fast_tw)
+        engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
+      else
+        engine.loadExternalTwiddlesGeneric32(
+          external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
      engine.twiddlesExternal();
     }
-    engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
+    if (fast_tw)
+      engine.loadInternalTwiddles32(internal_twiddles, strided);
+    else
+      engine.loadInternalTwiddlesGeneric32(internal_twiddles, strided, inv);
     engine.ntt4_2();
     engine.SharedData32Columns4_2(shmem, true, false, strided); // store
     __syncthreads();
@@ -290,7 +330,8 @@ namespace ntt {
     bool strided,
     uint32_t stage_num,
     bool inv,
-    bool dit)
+    bool dit,
+    bool fast_tw)
   {
     NTTEngine<E, S> engine;
     stage_metadata s_meta;
@@ -305,9 +346,15 @@ namespace ntt {
 
     if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
 
-    engine.loadBasicTwiddles(basic_twiddles, inv);
+    if (fast_tw)
+      engine.loadBasicTwiddles(basic_twiddles);
+    else
+      engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
     engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
-    engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
+    if (fast_tw)
+      engine.loadInternalTwiddles16(internal_twiddles, strided);
+    else
+      engine.loadInternalTwiddlesGeneric16(internal_twiddles, strided, inv);
     engine.ntt8win();
     engine.twiddlesInternal();
     engine.SharedData16Columns8(shmem, true, false, strided); // store
@@ -315,8 +362,11 @@ namespace ntt {
     engine.SharedData16Rows2_4(shmem, false, false, strided); // load
     engine.ntt2_4();
     if (twiddle_stride) {
-      engine.loadExternalTwiddlesGeneric16(
-        external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
+      if (fast_tw)
+        engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
+      else
+        engine.loadExternalTwiddlesGeneric16(
+          external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
       engine.twiddlesExternal();
     }
     engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
@@ -338,7 +388,8 @@ namespace ntt {
     bool strided,
     uint32_t stage_num,
     bool inv,
-    bool dit)
+    bool dit,
+    bool fast_tw)
   {
     NTTEngine<E, S> engine;
     stage_metadata s_meta;
@@ -353,14 +404,23 @@ namespace ntt {
 
     if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
 
-    engine.loadBasicTwiddles(basic_twiddles, inv);
+    if (fast_tw)
+      engine.loadBasicTwiddles(basic_twiddles);
+    else
+      engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
     engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
     if (twiddle_stride) {
-      engine.loadExternalTwiddlesGeneric16(
-        external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
+      if (fast_tw)
+        engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
+      else
+        engine.loadExternalTwiddlesGeneric16(
+          external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
       engine.twiddlesExternal();
     }
-    engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
+    if (fast_tw)
+      engine.loadInternalTwiddles16(internal_twiddles, strided);
+    else
+      engine.loadInternalTwiddlesGeneric16(internal_twiddles, strided, inv);
     engine.ntt2_4();
     engine.SharedData16Columns2_4(shmem, true, false, strided); // store
     __syncthreads();
@@ -388,8 +448,9 @@ namespace ntt {
     }
   }
 
+  // Generic twiddles: 1N twiddles for forward and inverse NTT
   template <typename S>
-  __global__ void generate_basic_twiddles(S basic_root, S* w6_table, S* basic_twiddles)
+  __global__ void generate_basic_twiddles_generic(S basic_root, S* w6_table, S* basic_twiddles)
   {
     S w0 = basic_root * basic_root;
     S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1);
@@ -484,7 +545,7 @@ namespace ntt {
     if (log_size > 2)
       for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++)
         temp_root = temp_root * temp_root;
-    generate_basic_twiddles<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles);
+    generate_basic_twiddles_generic<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles);
 
     const int NOF_BLOCKS = (log_size >= 8) ? (1 << (log_size - 8)) : 1;
     const int NOF_THREADS = (log_size >= 8) ? 256 : (1 << log_size);
@@ -501,6 +562,100 @@ namespace ntt {
     return CHK_LAST();
   }
 
+  // Fast-twiddles: 2N twiddles for forward, 2N for inverse
+  template <typename S>
+  __global__ void generate_basic_twiddles_fast_twiddles_mode(S basic_root, S* basic_twiddles)
+  {
+    S w0 = basic_root * basic_root;
+    S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1);
+    S w2 = (basic_root - w0 * basic_root) * S::inv_log_size(1);
+    basic_twiddles[0] = w0;
+    basic_twiddles[1] = w1;
+    basic_twiddles[2] = w2;
+  }
+
+  template <typename S>
+  __global__ void generate_twiddle_combinations_fast_twiddles_mode(
+    S* w6_table,
+    S* w12_table,
+    S* w18_table,
+    S* w24_table,
+    S* w30_table,
+    S* external_twiddles,
+    uint32_t log_size,
+    uint32_t prev_log_size)
+  {
+    uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+    uint32_t exp = ((tid & ((1 << prev_log_size) - 1)) * (tid >> prev_log_size)) << (30 - log_size);
+    S w6, w12, w18, w24, w30;
+    w6 = w6_table[exp >> 24];
+    w12 = w12_table[((exp >> 18) & 0x3f)];
+    w18 = w18_table[((exp >> 12) & 0x3f)];
+    w24 = w24_table[((exp >> 6) & 0x3f)];
+    w30 = w30_table[(exp & 0x3f)];
+    S t = w6 * w12 * w18 * w24 * w30;
+    external_twiddles[tid + (1 << log_size) - 1] = t;
+  }
+
+  template <typename S>
+  cudaError_t generate_external_twiddles_fast_twiddles_mode(
+    const S& basic_root,
+    S* external_twiddles,
+    S*& internal_twiddles,
+    S*& basic_twiddles,
+    uint32_t log_size,
+    cudaStream_t& stream)
+  {
+    CHK_INIT_IF_RETURN();
+
+    S* w6_table;
+    S* w12_table;
+    S* w18_table;
+    S* w24_table;
+    S* w30_table;
+    CHK_IF_RETURN(cudaMallocAsync(&w6_table, sizeof(S) * 64, stream));
+    CHK_IF_RETURN(cudaMallocAsync(&w12_table, sizeof(S) * 64, stream));
+    CHK_IF_RETURN(cudaMallocAsync(&w18_table, sizeof(S) * 64, stream));
+    CHK_IF_RETURN(cudaMallocAsync(&w24_table, sizeof(S) * 64, stream));
+    CHK_IF_RETURN(cudaMallocAsync(&w30_table, sizeof(S) * 64, stream));
+    CHK_IF_RETURN(cudaMallocAsync(&basic_twiddles, 3 * sizeof(S), stream));
+
+    S temp_root = basic_root;
+    generate_base_table<<<1, 1, 0, stream>>>(basic_root, w30_table, 1 << (30 - log_size));
+    if (log_size > 24)
+      for (int i = 0; i < 6 - (30 - log_size); i++)
+        temp_root = temp_root * temp_root;
+    generate_base_table<<<1, 1, 0, stream>>>(temp_root, w24_table, 1 << (log_size > 24 ? 0 : 24 - log_size));
+    if (log_size > 18)
+      for (int i = 0; i < 6 - (log_size > 24 ? 0 : 24 - log_size); i++)
+        temp_root = temp_root * temp_root;
+    generate_base_table<<<1, 1, 0, stream>>>(temp_root, w18_table, 1 << (log_size > 18 ? 0 : 18 - log_size));
+    if (log_size > 12)
+      for (int i = 0; i < 6 - (log_size > 18 ? 0 : 18 - log_size); i++)
+        temp_root = temp_root * temp_root;
+    generate_base_table<<<1, 1, 0, stream>>>(temp_root, w12_table, 1 << (log_size > 12 ? 0 : 12 - log_size));
+    if (log_size > 6)
+      for (int i = 0; i < 6 - (log_size > 12 ? 0 : 12 - log_size); i++)
+        temp_root = temp_root * temp_root;
+    generate_base_table<<<1, 1, 0, stream>>>(temp_root, w6_table, 1 << (log_size > 6 ? 0 : 6 - log_size));
+    for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++)
+      temp_root = temp_root * temp_root;
+    generate_basic_twiddles_fast_twiddles_mode<<<1, 1, 0, stream>>>(temp_root, basic_twiddles);
+
+    for (int i = 8; i < log_size + 1; i++) {
+      generate_twiddle_combinations_fast_twiddles_mode<<<1 << (i - 8), 256, 0, stream>>>(
+        w6_table, w12_table, w18_table, w24_table, w30_table, external_twiddles, i, STAGE_PREV_SIZES[i]);
+    }
+    internal_twiddles = w6_table;
+
+    CHK_IF_RETURN(cudaFreeAsync(w12_table, stream));
+    CHK_IF_RETURN(cudaFreeAsync(w18_table, stream));
+    CHK_IF_RETURN(cudaFreeAsync(w24_table, stream));
+    CHK_IF_RETURN(cudaFreeAsync(w30_table, stream));
+
+    return CHK_LAST();
+  }
+
   template <typename E, typename S>
   cudaError_t large_ntt(
     E* in,
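The generation splits a 30-bit exponent into five 6-bit digits, so any needed power of the root is a product of five 64-entry table lookups. A hedged host model of the lookup (the table layout here is an assumption inferred from the kernel's indexing: wX_table[k] == root^(k << (30 - X)) for k = 0..63, where root has order 2^30):

    #include <cstdint>

    // Sketch: reconstruct root^exp from the five digit tables, mirroring
    // generate_twiddle_combinations_fast_twiddles_mode. The top 6 bits index
    // w6_table (whose entries have order dividing 2^6), the bottom 6 bits
    // index w30_table, and the product over all digits equals root^exp.
    template <typename S>
    S pow_from_digit_tables(
      const S* w6_table, const S* w12_table, const S* w18_table, const S* w24_table, const S* w30_table, uint32_t exp)
    {
      return w6_table[exp >> 24] * w12_table[(exp >> 18) & 0x3f] * w18_table[(exp >> 12) & 0x3f] *
             w24_table[(exp >> 6) & 0x3f] * w30_table[exp & 0x3f];
    }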
@@ -514,6 +669,7 @@ namespace ntt {
     bool inv,
     bool normalize,
     bool dit,
+    bool fast_tw,
     cudaStream_t cuda_stream)
   {
     CHK_INIT_IF_RETURN();
@@ -529,11 +685,11 @@ namespace ntt {
     if (dit) {
       ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
-        false, 0, inv, dit);
+        false, 0, inv, dit, fast_tw);
     } else { // dif
       ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
-        false, 0, inv, dit);
+        false, 0, inv, dit, fast_tw);
     }
     if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
     return CHK_LAST();
@@ -545,11 +701,11 @@ namespace ntt {
     if (dit) {
       ntt32dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
-        false, 0, inv, dit);
+        false, 0, inv, dit, fast_tw);
     } else { // dif
       ntt32<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
-        false, 0, inv, dit);
+        false, 0, inv, dit, fast_tw);
     }
     if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
     return CHK_LAST();
@@ -560,7 +716,7 @@ namespace ntt {
     const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
     ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
       in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
-      false, 0, inv, dit);
+      false, 0, inv, dit, fast_tw);
     if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
     return CHK_LAST();
   }
@@ -571,17 +727,17 @@ namespace ntt {
     if (dit) {
       ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-        (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit);
+        (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
       ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-        (1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit);
+        (1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
     } else { // dif
       ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-        (1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit);
+        (1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
       ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
         out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-        (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit);
+        (1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
     }
     if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
     return CHK_LAST();
@@ -591,43 +747,49 @@ namespace ntt {
     uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
     if (dit) {
       for (int i = 0; i < 5; i++) {
-        uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
+        uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
         uint32_t stride_log = 0;
         for (int j = 0; j < i; j++)
-          stride_log += STAGE_SIZES_HOST[log_size][j];
+          stride_log += fast_tw ? STAGE_SIZES_HOST_FT[log_size][j] : STAGE_SIZES_HOST[log_size][j];
         if (stage_size == 6)
           ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
             i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-            (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
+            (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
+            fast_tw);
         else if (stage_size == 5)
           ntt32dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
             i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-            (1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
+            (1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
+            fast_tw);
         else if (stage_size == 4)
           ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
             i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-            (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
+            (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
+            fast_tw);
       }
     } else { // dif
       bool first_run = false, prev_stage = false;
       for (int i = 4; i >= 0; i--) {
-        uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
+        uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
         uint32_t stride_log = 0;
         for (int j = 0; j < i; j++)
-          stride_log += STAGE_SIZES_HOST[log_size][j];
+          stride_log += fast_tw ? STAGE_SIZES_HOST_FT[log_size][j] : STAGE_SIZES_HOST[log_size][j];
         first_run = stage_size && !prev_stage;
         if (stage_size == 6)
           ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
             first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-            (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
+            (1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
+            fast_tw);
         else if (stage_size == 5)
           ntt32<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
             first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-            (1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
+            (1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
+            fast_tw);
        else if (stage_size == 4)
          ntt16<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
            first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
-            (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
+            (1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
+            fast_tw);
         prev_stage = stage_size;
       }
     }
@@ -648,6 +810,7 @@ namespace ntt {
     int max_logn,
     int batch_size,
     bool is_inverse,
+    bool fast_tw,
     Ordering ordering,
     S* arbitrary_coset,
     int coset_gen_index,
@@ -706,10 +869,10 @@ namespace ntt {
     const bool is_reverse_in_place = (d_input == d_output);
     if (is_reverse_in_place) {
       reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
-        d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
+        d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
     } else {
       reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
-        d_input, d_output, logn, dit, reverse_input, is_normalize, S::inv_log_size(logn));
+        d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
     }
     is_normalize = false;
     d_input = d_output;
@@ -718,11 +881,11 @@ namespace ntt {
     // inplace ntt
     CHK_IF_RETURN(large_ntt(
       d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
-      (is_normalize && reverse_output == eRevType::None), dit, cuda_stream));
+      (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
 
     if (reverse_output != eRevType::None) {
       reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
-        d_output, logn, dit, reverse_output, is_normalize, S::inv_log_size(logn));
+        d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
     }
 
     if (is_on_coset && is_inverse) {
@@ -743,6 +906,14 @@ namespace ntt {
     uint32_t log_size,
     cudaStream_t& stream);
 
+  template cudaError_t generate_external_twiddles_fast_twiddles_mode(
+    const curve_config::scalar_t& basic_root,
+    curve_config::scalar_t* external_twiddles,
+    curve_config::scalar_t*& internal_twiddles,
+    curve_config::scalar_t*& basic_twiddles,
+    uint32_t log_size,
+    cudaStream_t& stream);
+
   template cudaError_t mixed_radix_ntt<curve_config::scalar_t, curve_config::scalar_t>(
     curve_config::scalar_t* d_input,
     curve_config::scalar_t* d_output,
@@ -753,6 +924,7 @@ namespace ntt {
     int max_logn,
     int batch_size,
     bool is_inverse,
+    bool fast_tw,
     Ordering ordering,
     curve_config::scalar_t* arbitrary_coset,
     int coset_gen_index,
@@ -370,14 +370,23 @@ namespace ntt {
     int max_size = 0;
     int max_log_size = 0;
     S* twiddles = nullptr;
+    bool initialized = false; // protection for multi-threaded case
     std::unordered_map<S, int> coset_index = {};
 
     S* internal_twiddles = nullptr; // required by mixed-radix NTT
     S* basic_twiddles = nullptr;    // required by mixed-radix NTT
 
+    // mixed-radix NTT supports a fast-twiddles option at the cost of an additional 4N memory (where N is the max NTT size)
+    S* fast_external_twiddles = nullptr;     // required by mixed-radix NTT (fast-twiddles mode)
+    S* fast_internal_twiddles = nullptr;     // required by mixed-radix NTT (fast-twiddles mode)
+    S* fast_basic_twiddles = nullptr;        // required by mixed-radix NTT (fast-twiddles mode)
+    S* fast_external_twiddles_inv = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
+    S* fast_internal_twiddles_inv = nullptr; // required by mixed-radix NTT (fast-twiddles mode)
+    S* fast_basic_twiddles_inv = nullptr;    // required by mixed-radix NTT (fast-twiddles mode)
+
   public:
     template <typename U>
-    friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
+    friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx, bool fast_tw);
 
     cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
 
@@ -389,7 +398,7 @@ namespace ntt {
   static inline Domain<S> domains_for_devices[device_context::MAX_DEVICES] = {};
 
   template <typename S>
-  cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
+  cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode)
   {
     CHK_INIT_IF_RETURN();
 
@@ -399,11 +408,11 @@ namespace ntt {
     // please note that this offers just basic thread-safety,
     // it's assumed a singleton (non-enforced) that is supposed
     // to be initialized once per device per program lifetime
-    if (!domain.twiddles) {
+    if (!domain.initialized) {
       // Mutex is automatically released when lock goes out of scope, even in case of exceptions
       std::lock_guard<std::mutex> lock(Domain<S>::device_domain_mutex);
       // double check locking
-      if (domain.twiddles) return CHK_LAST(); // another thread is already initializing the domain
+      if (domain.initialized) return CHK_LAST(); // another thread is already initializing the domain
 
       bool found_logn = false;
       S omega = primitive_root;
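The `initialized` flag replaces the twiddles-pointer check because initialization now has several steps (generic tables plus, optionally, the fast-twiddle tables); a non-null `twiddles` could otherwise publish a half-built domain. A minimal sketch of the double-checked locking shape used here (assuming `initialized` is only written under the mutex):

    if (!domain.initialized) {                   // fast path: cheap unlocked check
      std::lock_guard<std::mutex> lock(Domain<S>::device_domain_mutex);
      if (domain.initialized) return CHK_LAST(); // another thread won the race
      /* ... build all twiddle tables ... */
      domain.initialized = true;                 // publish only once everything is ready
    }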
@@ -430,6 +439,25 @@ namespace ntt {
       CHK_IF_RETURN(generate_external_twiddles_generic(
         primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
         ctx.stream));
+
+      if (fast_twiddles_mode) {
+        // generating fast-twiddles (note that this costs 4N additional memory)
+        CHK_IF_RETURN(cudaMallocAsync(&domain.fast_external_twiddles, domain.max_size * sizeof(S) * 2, ctx.stream));
+        CHK_IF_RETURN(cudaMallocAsync(&domain.fast_external_twiddles_inv, domain.max_size * sizeof(S) * 2, ctx.stream));
+
+        // fast-twiddles forward NTT
+        CHK_IF_RETURN(generate_external_twiddles_fast_twiddles_mode(
+          primitive_root, domain.fast_external_twiddles, domain.fast_internal_twiddles, domain.fast_basic_twiddles,
+          domain.max_log_size, ctx.stream));
+
+        // fast-twiddles inverse NTT
+        S primitive_root_inv;
+        CHK_IF_RETURN(cudaMemcpyAsync(
+          &primitive_root_inv, &domain.twiddles[domain.max_size - 1], sizeof(S), cudaMemcpyDeviceToHost, ctx.stream));
+        CHK_IF_RETURN(generate_external_twiddles_fast_twiddles_mode(
+          primitive_root_inv, domain.fast_external_twiddles_inv, domain.fast_internal_twiddles_inv,
+          domain.fast_basic_twiddles_inv, domain.max_log_size, ctx.stream));
+      }
       CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
 
       const bool is_map_only_powers_of_primitive_root = true;
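To make the 4N figure concrete: the two allocations above reserve 2N scalars each (forward and inverse external tables) on top of the generic 1N domain. A hedged back-of-the-envelope helper (the per-scalar byte size is an assumption; 32 bytes is typical for a 256-bit field):

    #include <cstddef>
    #include <cstdint>

    // Sketch: extra device memory consumed by fast-twiddles mode for a domain
    // of max size N = 2^max_log_size.
    size_t fast_twiddles_extra_bytes(uint32_t max_log_size, size_t scalar_bytes = 32)
    {
      const size_t N = size_t{1} << max_log_size;
      return 4 * N * scalar_bytes; // 2N forward + 2N inverse external twiddles
    }
    // e.g. max_log_size = 24 with 32-byte scalars: 4 * 2^24 * 32 B = 2 GiB extra.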
@@ -447,6 +475,7 @@ namespace ntt {
           domain.coset_index[domain.twiddles[i]] = i;
         }
       }
+      domain.initialized = true;
     }
 
     return CHK_LAST();
@@ -467,6 +496,19 @@ namespace ntt {
     basic_twiddles = nullptr;
     coset_index.clear();
 
+    cudaFreeAsync(fast_external_twiddles, ctx.stream);
+    fast_external_twiddles = nullptr;
+    cudaFreeAsync(fast_internal_twiddles, ctx.stream);
+    fast_internal_twiddles = nullptr;
+    cudaFreeAsync(fast_basic_twiddles, ctx.stream);
+    fast_basic_twiddles = nullptr;
+    cudaFreeAsync(fast_external_twiddles_inv, ctx.stream);
+    fast_external_twiddles_inv = nullptr;
+    cudaFreeAsync(fast_internal_twiddles_inv, ctx.stream);
+    fast_internal_twiddles_inv = nullptr;
+    cudaFreeAsync(fast_basic_twiddles_inv, ctx.stream);
+    fast_basic_twiddles_inv = nullptr;
+
     return CHK_LAST();
   }
 
@@ -607,9 +649,21 @@ namespace ntt {
         d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
         coset_index, stream));
     } else {
+      const bool is_on_coset = (coset_index != 0) || coset;
+      const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
+      S* twiddles = is_fast_twiddles_enabled
+                      ? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
+                      : domain.twiddles;
+      S* internal_twiddles = is_fast_twiddles_enabled
+                               ? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
+                               : domain.internal_twiddles;
+      S* basic_twiddles = is_fast_twiddles_enabled
+                            ? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
+                            : domain.basic_twiddles;
+
       CHK_IF_RETURN(ntt::mixed_radix_ntt(
-        d_input, d_output, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, size, domain.max_log_size,
-        batch_size, is_inverse, config.ordering, coset, coset_index, stream));
+        d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
+        is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
     }
 
     if (!are_outputs_on_device)
@@ -645,10 +699,10 @@ namespace ntt {
  * value of template parameter (where the curve is given by `-DCURVE` env variable during build):
  *  - `S` is the [scalar field](@ref scalar_t) of the curve;
  */
-extern "C" cudaError_t
-CONCAT_EXPAND(CURVE, InitializeDomain)(curve_config::scalar_t primitive_root, device_context::DeviceContext& ctx)
+extern "C" cudaError_t CONCAT_EXPAND(CURVE, InitializeDomain)(
+  curve_config::scalar_t primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode)
 {
-  return InitDomain(primitive_root, ctx);
+  return InitDomain(primitive_root, ctx, fast_twiddles_mode);
 }
 
 /**
@@ -32,10 +32,13 @@ namespace ntt {
  * @param primitive_root Primitive root in field `S` of order \f$ 2^s \f$. This should be the smallest power-of-2
  * order that's large enough to support any NTT you might want to perform.
  * @param ctx Details related to the device such as its id and stream id.
+ * @param fast_twiddles_mode A mode where more memory is allocated for twiddle factors in exchange for faster compute.
+ * This mode requires an additional 4N memory, where N is the largest NTT size to be supported (as derived from
+ * primitive_root).
  * @return `cudaSuccess` if the execution was successful and an error code otherwise.
  */
 template <typename S>
-cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx);
+cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx, bool fast_twiddles_mode = false);
 
 /**
  * @enum NTTDir
@@ -16,6 +16,15 @@ namespace ntt {
     uint32_t log_size,
     cudaStream_t& stream);
 
+  template <typename S>
+  cudaError_t generate_external_twiddles_fast_twiddles_mode(
+    const S& basic_root,
+    S* external_twiddles,
+    S*& internal_twiddles,
+    S*& basic_twiddles,
+    uint32_t log_size,
+    cudaStream_t& stream);
+
   template <typename E, typename S>
   cudaError_t mixed_radix_ntt(
     E* d_input,
@@ -27,6 +36,7 @@ namespace ntt {
     int max_logn,
     int batch_size,
     bool is_inverse,
+    bool fast_tw,
     Ordering ordering,
     S* arbitrary_coset,
     int coset_gen_index,
@@ -34,13 +34,14 @@ int main(int argc, char** argv)
   cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
   float icicle_time, new_time;
 
-  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 4; // assuming the first argument is the log-size
+  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19;
   int NTT_SIZE = 1 << NTT_LOG_SIZE;
   bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
-  int INV = (argc > 3) ? atoi(argv[3]) : false;
+  int INV = (argc > 3) ? atoi(argv[3]) : true;
   int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
   int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
   const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
+  bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true;
 
   // Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
   const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
@@ -51,8 +52,8 @@ int main(int argc, char** argv)
                                                              : "MN";
 
   printf(
-    "running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s\n", NTT_LOG_SIZE, INPLACE, INV,
-    BATCH_SIZE, COSET_IDX, ordering_str);
+    "running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE,
+    INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW);
 
   CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
 
@@ -70,7 +71,7 @@ int main(int argc, char** argv)
 
   auto start = std::chrono::high_resolution_clock::now();
   const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
-  ntt::InitDomain(basic_root, ntt_config.ctx);
+  ntt::InitDomain(basic_root, ntt_config.ctx, FAST_TW);
   auto stop = std::chrono::high_resolution_clock::now();
   auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
   std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;
@@ -13,19 +13,33 @@ struct stage_metadata {
   uint32_t ntt_inp_id;
 };
 
-uint32_t constexpr STAGE_SIZES_HOST[31][5] = {
-  {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
-  {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
-  {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
-  {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
-  {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
+#define STAGE_SIZES_DATA \
+  { \
+    {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, \
+      {6, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, \
+      {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0}, {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, \
+      {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0}, {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, \
+      {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6}, {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, \
+      {6, 6, 6, 6, 6}, \
+  }
+uint32_t constexpr STAGE_SIZES_HOST[31][5] = STAGE_SIZES_DATA;
+__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = STAGE_SIZES_DATA;
 
-__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = {
-  {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
-  {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
-  {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
-  {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
-  {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
+// construction for fast-twiddles
+uint32_t constexpr STAGE_PREV_SIZES[31] = {0,  0,  0,  0,  0,  0,  0,  0,  4,  5,  5,  6,  6,  9,  9,  10,
+                                           11, 11, 12, 15, 15, 16, 16, 18, 18, 20, 21, 21, 22, 23, 24};
+
+#define STAGE_SIZES_DATA_FAST_TW \
+  { \
+    {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, \
+      {6, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, \
+      {6, 6, 0, 0, 0}, {5, 4, 4, 0, 0}, {5, 4, 5, 0, 0}, {5, 5, 5, 0, 0}, {6, 5, 5, 0, 0}, {6, 5, 6, 0, 0}, \
+      {6, 6, 6, 0, 0}, {5, 5, 5, 4, 0}, {5, 5, 5, 5, 0}, {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, \
+      {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 5, 5, 5}, {6, 5, 5, 5, 6}, {6, 5, 5, 6, 6}, {6, 6, 6, 5, 6}, \
+      {6, 6, 6, 6, 6}, \
+  }
+uint32_t constexpr STAGE_SIZES_HOST_FT[31][5] = STAGE_SIZES_DATA_FAST_TW;
+__device__ uint32_t constexpr STAGE_SIZES_DEVICE_FT[31][5] = STAGE_SIZES_DATA_FAST_TW;
 
 template <typename E, typename S>
 class NTTEngine
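Two invariants tie these tables together: each supported row's digits sum to its log size, and STAGE_PREV_SIZES[n] equals that sum minus the final stage size of the fast-twiddles row, i.e. the combined size already transformed before the last stage runs. A hedged compile-time check (a sketch, C++14 constexpr, shown for one representative size):

    constexpr uint32_t row_sum(const uint32_t (&row)[5])
    {
      return row[0] + row[1] + row[2] + row[3] + row[4];
    }
    constexpr uint32_t last_stage(const uint32_t (&row)[5])
    {
      uint32_t last = 0;
      for (int i = 0; i < 5; i++)
        if (row[i]) last = row[i]; // final non-zero digit
      return last;
    }
    // every supported size decomposes exactly into its stages...
    static_assert(row_sum(STAGE_SIZES_HOST_FT[19]) == 19, "stages must sum to log size");
    // ...and the prev-size is the digit-sum of all but the last stage
    static_assert(STAGE_PREV_SIZES[19] == 19 - last_stage(STAGE_SIZES_HOST_FT[19]), "prev-size invariant");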
@@ -36,7 +50,15 @@ public:
   S WI[7];
   S WE[8];
 
-  __device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles, bool inv)
+  __device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles)
+  {
+#pragma unroll
+    for (int i = 0; i < 3; i++) {
+      WB[i] = basic_twiddles[i];
+    }
+  }
+
+  __device__ __forceinline__ void loadBasicTwiddlesGeneric(S* basic_twiddles, bool inv)
   {
 #pragma unroll
     for (int i = 0; i < 3; i++) {
@@ -44,7 +66,31 @@ public:
     }
   }
 
-  __device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride, bool inv)
+  __device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride)
+  {
+#pragma unroll
+    for (int i = 0; i < 7; i++) {
+      WI[i] = data[((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1)];
+    }
+  }
+
+  __device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride)
+  {
+#pragma unroll
+    for (int i = 0; i < 7; i++) {
+      WI[i] = data[2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1)];
+    }
+  }
+
+  __device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride)
+  {
+#pragma unroll
+    for (int i = 0; i < 7; i++) {
+      WI[i] = data[4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1)];
+    }
+  }
+
+  __device__ __forceinline__ void loadInternalTwiddlesGeneric64(S* data, bool stride, bool inv)
   {
 #pragma unroll
     for (int i = 0; i < 7; i++) {
@@ -53,7 +99,7 @@ public:
     }
   }
 
-  __device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride, bool inv)
+  __device__ __forceinline__ void loadInternalTwiddlesGeneric32(S* data, bool stride, bool inv)
   {
 #pragma unroll
     for (int i = 0; i < 7; i++) {
@@ -62,7 +108,7 @@ public:
     }
   }
 
-  __device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride, bool inv)
+  __device__ __forceinline__ void loadInternalTwiddlesGeneric16(S* data, bool stride, bool inv)
   {
 #pragma unroll
     for (int i = 0; i < 7; i++) {
@@ -71,8 +117,47 @@ public:
     }
   }
 
+  __device__ __forceinline__ void
+  loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
+  {
+    data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      WE[i] = data[8 * i * tw_order + (1 << tw_log_order + 6) - 1];
+    }
+  }
+
+  __device__ __forceinline__ void
+  loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
+  {
+    data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));
+
+#pragma unroll
+    for (uint32_t j = 0; j < 2; j++) {
+#pragma unroll
+      for (uint32_t i = 0; i < 4; i++) {
+        WE[4 * j + i] = data[(8 * i + j) * tw_order + (1 << tw_log_order + 5) - 1];
+      }
+    }
+  }
+
+  __device__ __forceinline__ void
+  loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
+  {
+    data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));
+
+#pragma unroll
+    for (uint32_t j = 0; j < 4; j++) {
+#pragma unroll
+      for (uint32_t i = 0; i < 2; i++) {
+        WE[2 * j + i] = data[(8 * i + j) * tw_order + (1 << tw_log_order + 4) - 1];
+      }
+    }
+  }
+
   __device__ __forceinline__ void loadExternalTwiddlesGeneric64(
-    E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
+    S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
   {
 #pragma unroll
     for (uint32_t i = 0; i < 8; i++) {
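The fast loaders appear to index a flat layout where the twiddle block for a merged size of 2^(tw_log_order + 6) begins at offset 2^(tw_log_order + 6) - 1, matching the `tid + (1 << log_size) - 1` write in the generation kernel (note that `1 << tw_log_order + 6` parses as `1 << (tw_log_order + 6)` since `+` binds tighter than `<<`). A hedged helper spelling out the 64-point index:

    #include <cstdint>

    // Sketch: flat index the 64-point fast loader reads twiddle i from,
    // assuming the layout written by generate_twiddle_combinations_fast_twiddles_mode.
    uint32_t fast_tw64_index(
      uint32_t tw_order, uint32_t tw_log_order, uint32_t ntt_inp_id, uint32_t ntt_block_id, uint32_t i)
    {
      const uint32_t base = tw_order * ntt_inp_id + (ntt_block_id & (tw_order - 1));
      return base + 8 * i * tw_order + (1u << (tw_log_order + 6)) - 1;
    }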
@@ -83,7 +168,7 @@ public:
     }
   }
 
   __device__ __forceinline__ void loadExternalTwiddlesGeneric32(
-    E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
+    S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
   {
 #pragma unroll
     for (uint32_t j = 0; j < 2; j++) {
@@ -97,7 +182,7 @@ public:
     }
   }
 
   __device__ __forceinline__ void loadExternalTwiddlesGeneric16(
-    E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
+    S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
   {
 #pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
@@ -119,6 +119,7 @@ pub trait NTT<F: FieldImpl> {
         output: &mut HostOrDeviceSlice<F>,
     ) -> IcicleResult<()>;
     fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
+    fn initialize_domain_fast_twiddles_mode(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
 }
 
 /// Computes the NTT, or a batch of several NTTs.
@@ -172,6 +173,13 @@ where
 {
     <<F as FieldImpl>::Config as NTT<F>>::initialize_domain(primitive_root, ctx)
 }
+pub fn initialize_domain_fast_twiddles_mode<F>(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>
+where
+    F: FieldImpl,
+    <F as FieldImpl>::Config: NTT<F>,
+{
+    <<F as FieldImpl>::Config as NTT<F>>::initialize_domain_fast_twiddles_mode(primitive_root, ctx)
+}
 
 #[macro_export]
 macro_rules! impl_ntt {
@@ -195,7 +203,11 @@ macro_rules! impl_ntt {
             ) -> CudaError;
 
             #[link_name = concat!($field_prefix, "InitializeDomain")]
-            pub(crate) fn initialize_ntt_domain(primitive_root: $field, ctx: &DeviceContext) -> CudaError;
+            pub(crate) fn initialize_ntt_domain(
+                primitive_root: $field,
+                ctx: &DeviceContext,
+                fast_twiddles_mode: bool,
+            ) -> CudaError;
         }
     }
 
@@ -219,7 +231,10 @@ macro_rules! impl_ntt {
             }
 
             fn initialize_domain(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
-                unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx).wrap() }
+                unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx, false).wrap() }
+            }
+            fn initialize_domain_fast_twiddles_mode(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
+                unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx, true).wrap() }
             }
         }
     };
@@ -232,28 +247,29 @@ macro_rules! impl_ntt_tests {
     ) => {
         const MAX_SIZE: u64 = 1 << 17;
         static INIT: OnceLock<()> = OnceLock::new();
+        const FAST_TWIDDLES_MODE: bool = false;
 
         #[test]
         fn test_ntt() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
             check_ntt::<$field>()
         }
 
         #[test]
         fn test_ntt_coset_from_subgroup() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
             check_ntt_coset_from_subgroup::<$field>()
         }
 
         #[test]
         fn test_ntt_arbitrary_coset() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
             check_ntt_arbitrary_coset::<$field>()
         }
 
         #[test]
         fn test_ntt_batch() {
-            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID));
+            INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
             check_ntt_batch::<$field>()
         }
 
@@ -9,21 +9,25 @@ use rayon::iter::IntoParallelIterator;
 use rayon::iter::ParallelIterator;
 
 use crate::{
-    ntt::{initialize_domain, ntt, NTTDir, NttAlgorithm, Ordering},
+    ntt::{initialize_domain, initialize_domain_fast_twiddles_mode, ntt, NTTDir, NttAlgorithm, Ordering},
     traits::{ArkConvertible, FieldImpl, GenerateRandom},
 };
 
 use super::NTTConfig;
 use super::NTT;
 
-pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize)
+pub fn init_domain<F: FieldImpl + ArkConvertible>(max_size: u64, device_id: usize, fast_twiddles_mode: bool)
 where
     F::ArkEquivalent: FftField,
     <F as FieldImpl>::Config: NTT<F>,
 {
     let ctx = DeviceContext::default_for_device(device_id);
     let ark_rou = F::ArkEquivalent::get_root_of_unity(max_size).unwrap();
-    initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
+    if fast_twiddles_mode {
+        initialize_domain_fast_twiddles_mode(F::from_ark(ark_rou), &ctx).unwrap();
+    } else {
+        initialize_domain(F::from_ark(ark_rou), &ctx).unwrap();
+    }
 }
 
 pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
@@ -289,7 +293,8 @@ where
         .into_par_iter()
         .for_each(move |device_id| {
             set_device(device_id).unwrap();
-            init_domain::<F>(1 << 16, device_id); // init domain per device
+            // with more than one device this exercises fast-twiddles mode (note that the domain is reused per device if not released)
+            init_domain::<F>(1 << 16, device_id, true /*=fast twiddles mode*/); // init domain per device
             let test_sizes = [1 << 4, 1 << 12];
             let batch_sizes = [1, 1 << 4, 100];
             for test_size in test_sizes {