Mirror of https://github.com/pseXperiments/icicle.git (synced 2026-01-10 07:57:56 -05:00)
NTT columns batch (#424)
This PR adds the columns batch feature, enabling batch NTT computation to be performed directly on the columns of a matrix without having to transpose it beforehand, as requested in issue #264. Some small fixes to the reordering kernels were also added, and some unnecessary parameters were removed from function interfaces.

---------

Co-authored-by: DmytroTym <dmytrotym1@gmail.com>
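For orientation, here is a minimal host-side C++ sketch (not part of the commit; the helper name flat_index is purely illustrative) of the flattened-matrix addressing the column-batch kernels in this diff rely on: with columns_batch enabled, element i of batch b is accessed at i * batch_size + b, whereas the regular rows batch uses i + b * size.

// Sketch only: demonstrates the two addressing modes used by the kernels below.
// flat_index is an illustrative helper, not an ICICLE API.
#include <cstdint>
#include <cstdio>

static inline uint32_t flat_index(uint32_t i, uint32_t b, uint32_t size, uint32_t batch_size, bool columns_batch)
{
  // columns mode: elements of one NTT are batch_size apart; rows mode: they are contiguous.
  return columns_batch ? i * batch_size + b : i + b * size;
}

int main()
{
  const uint32_t size = 4, batch_size = 3; // three NTTs of length four
  for (uint32_t b = 0; b < batch_size; b++) {
    for (uint32_t i = 0; i < size; i++)
      printf("%2u ", (unsigned)flat_index(i, b, size, batch_size, /*columns_batch=*/true));
    printf(" <- flat offsets touched by batch %u in columns mode\n", (unsigned)b);
  }
  return 0;
}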
@@ -56,7 +56,15 @@ namespace ntt {
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void reorder_digits_inplace_and_normalize_kernel(
E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N)
E* arr,
uint32_t log_size,
bool columns_batch,
uint32_t batch_size,
bool dit,
bool fast_tw,
eRevType rev_type,
bool is_normalize,
S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -65,19 +73,20 @@ namespace ntt {

const uint32_t size = 1 << log_size;
const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t idx = tid % size;
const uint32_t batch_idx = tid / size;
const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
if (tid >= size * batch_size) return;

uint32_t next_element = idx;
uint32_t group[MAX_GROUP_SIZE];
group[0] = next_element + size * batch_idx;
group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;

uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element + size * batch_idx;
group[i++] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;
}

--i;
@@ -94,6 +103,9 @@ namespace ntt {
E* arr,
E* arr_reordered,
uint32_t log_size,
bool columns_batch,
uint32_t batch_size,
uint32_t columns_batch_size,
bool dit,
bool fast_tw,
eRevType rev_type,
@@ -101,29 +113,33 @@ namespace ntt {
S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= (1 << log_size) * batch_size) return;
uint32_t rd = tid;
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}

template <typename E, typename S>
static __global__ void batch_elementwise_mul_with_reorder(
static __global__ void batch_elementwise_mul_with_reorder_kernel(
E* in_vec,
int n_elements,
int batch_size,
uint32_t size,
bool columns_batch,
uint32_t batch_size,
uint32_t columns_batch_size,
S* scalar_vec,
int step,
int n_scalars,
int logn,
uint32_t log_size,
eRevType rev_type,
bool dit,
E* out_vec)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= n_elements * batch_size) return;
int64_t scalar_id = tid % n_elements;
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type);
if (tid >= size * batch_size) return;
int64_t scalar_id = (tid / columns_batch_size) % size;
if (rev_type != eRevType::None)
scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type);
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}

@@ -136,6 +152,7 @@ namespace ntt {
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -153,19 +170,27 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 8;
|
||||
s_meta.ntt_block_size = 64;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
|
||||
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 7) / 8)
|
||||
: (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0x7) + ((blockIdx.x % ((columns_batch_size + 7) / 8)) << 3) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (twiddle_stride && dit) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric64(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
@@ -189,13 +214,16 @@ namespace ntt {
|
||||
|
||||
if (twiddle_stride && !dit) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric64(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
@@ -207,6 +235,7 @@ namespace ntt {
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -225,16 +254,25 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 4;
|
||||
s_meta.ntt_block_size = 32;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
|
||||
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadInternalTwiddles32(internal_twiddles, strided);
|
||||
else
|
||||
@@ -247,13 +285,16 @@ namespace ntt {
|
||||
engine.ntt4_2();
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric32(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalData32ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData32(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
@@ -265,6 +306,7 @@ namespace ntt {
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -283,19 +325,27 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 4;
|
||||
s_meta.ntt_block_size = 32;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
|
||||
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalData32ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData32(in, data_stride, log_data_stride, strided, s_meta);
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric32(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
@@ -311,7 +361,10 @@ namespace ntt {
|
||||
engine.SharedData32Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
engine.ntt8win();
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
@@ -323,6 +376,7 @@ namespace ntt {
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -341,16 +395,26 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 2;
|
||||
s_meta.ntt_block_size = 16;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_block_id = columns_batch_size
|
||||
? blockIdx.x / ((columns_batch_size + 31) / 32)
|
||||
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadInternalTwiddles16(internal_twiddles, strided);
|
||||
else
|
||||
@@ -363,13 +427,16 @@ namespace ntt {
|
||||
engine.ntt2_4();
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric16(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalData16ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData16(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
@@ -381,6 +448,7 @@ namespace ntt {
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -399,19 +467,29 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 2;
|
||||
s_meta.ntt_block_size = 16;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_block_id = columns_batch_size
|
||||
? blockIdx.x / ((columns_batch_size + 31) / 32)
|
||||
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalData16ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData16(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric16(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
@@ -427,13 +505,17 @@ namespace ntt {
|
||||
engine.SharedData16Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
engine.ntt8win();
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
__global__ void normalize_kernel(E* data, S norm_factor)
__global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= size) return;
data[tid] = data[tid] * norm_factor;
}

@@ -666,6 +748,7 @@ namespace ntt {
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t batch_size,
|
||||
bool columns_batch,
|
||||
bool inv,
|
||||
bool normalize,
|
||||
bool dit,
|
||||
@@ -679,72 +762,83 @@ namespace ntt {
|
||||
}
|
||||
|
||||
if (log_size == 4) {
|
||||
const int NOF_THREADS = min(64, 2 * batch_size);
|
||||
const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
|
||||
const int NOF_THREADS = columns_batch ? 64 : min(64, 2 * batch_size);
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 31) / 32) : (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
if (dit) {
|
||||
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
} else { // dif
|
||||
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 5) {
|
||||
const int NOF_THREADS = min(64, 4 * batch_size);
|
||||
const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
const int NOF_THREADS = columns_batch ? 64 : min(64, 4 * batch_size);
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 15) / 16) : (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
if (dit) {
|
||||
ntt32dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
} else { // dif
|
||||
ntt32<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 6) {
|
||||
const int NOF_THREADS = min(64, 8 * batch_size);
|
||||
const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
const int NOF_THREADS = columns_batch ? 64 : min(64, 8 * batch_size);
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 7) / 8) : ((8 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
|
||||
ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 8) {
|
||||
const int NOF_THREADS = 64;
|
||||
const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 31) / 32 * 16) : ((32 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
|
||||
if (dit) {
|
||||
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
|
||||
columns_batch, 0, inv, dit, fast_tw);
|
||||
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
|
||||
inv, dit, fast_tw);
|
||||
} else { // dif
|
||||
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
|
||||
inv, dit, fast_tw);
|
||||
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
|
||||
columns_batch, 0, inv, dit, fast_tw);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
// general case:
|
||||
uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
|
||||
uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
|
||||
if (dit) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
|
||||
@@ -754,18 +848,18 @@ namespace ntt {
|
||||
if (stage_size == 6)
|
||||
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 5)
|
||||
ntt32dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 4)
|
||||
ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
}
|
||||
} else { // dif
|
||||
bool first_run = false, prev_stage = false;
|
||||
@@ -778,23 +872,24 @@ namespace ntt {
|
||||
if (stage_size == 6)
|
||||
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 5)
|
||||
ntt32<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 4)
|
||||
ntt16<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
prev_stage = stage_size;
|
||||
}
|
||||
}
|
||||
if (normalize)
|
||||
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
|
||||
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
|
||||
out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
@@ -809,6 +904,7 @@ namespace ntt {
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
bool columns_batch,
|
||||
bool is_inverse,
|
||||
bool fast_tw,
|
||||
Ordering ordering,
|
||||
@@ -858,9 +954,10 @@ namespace ntt {
|
||||
}
|
||||
|
||||
if (is_on_coset && !is_inverse) {
|
||||
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
|
||||
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
|
||||
arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
|
||||
reverse_coset, dit, d_output);
|
||||
|
||||
d_input = d_output;
|
||||
}
|
||||
@@ -869,10 +966,11 @@ namespace ntt {
|
||||
const bool is_reverse_in_place = (d_input == d_output);
|
||||
if (is_reverse_in_place) {
|
||||
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
} else {
|
||||
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
d_input, d_output, logn, columns_batch, batch_size, columns_batch ? batch_size : 1, dit, fast_tw,
|
||||
reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
}
|
||||
is_normalize = false;
|
||||
d_input = d_output;
|
||||
@@ -880,18 +978,19 @@ namespace ntt {
|
||||
|
||||
// inplace ntt
|
||||
CHK_IF_RETURN(large_ntt(
|
||||
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
|
||||
(is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
|
||||
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size,
|
||||
columns_batch, is_inverse, (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
|
||||
|
||||
if (reverse_output != eRevType::None) {
|
||||
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
|
||||
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
|
||||
}
|
||||
|
||||
if (is_on_coset && is_inverse) {
|
||||
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
|
||||
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
|
||||
arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
|
||||
n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
}
|
||||
|
||||
return CHK_LAST();
|
||||
@@ -923,6 +1022,7 @@ namespace ntt {
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
bool columns_batch,
|
||||
bool is_inverse,
|
||||
bool fast_tw,
|
||||
Ordering ordering,
|
||||
|
||||
@@ -516,12 +516,15 @@ namespace ntt {
|
||||
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
|
||||
{
|
||||
const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
|
||||
if (!is_mixed_radix_alg_supported && config.columns_batch)
|
||||
throw IcicleError(IcicleError_t::InvalidArgument, "columns batch is not supported for given NTT size");
|
||||
const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2;
|
||||
const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg;
|
||||
if (is_force_radix2) return true;
|
||||
|
||||
const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix;
|
||||
if (is_user_selected_mixed_radix_alg) return false;
|
||||
if (config.columns_batch) return false; // radix2 does not currently support columns batch mode.
|
||||
|
||||
// Heuristic to automatically select an algorithm
|
||||
// Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and
|
||||
@@ -663,7 +666,7 @@ namespace ntt {
|
||||
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
|
||||
is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
}
|
||||
|
||||
if (!are_outputs_on_device)
|
||||
@@ -685,6 +688,7 @@ namespace ntt {
|
||||
ctx, // ctx
|
||||
S::one(), // coset_gen
|
||||
1, // batch_size
|
||||
false, // columns_batch
|
||||
Ordering::kNN, // ordering
|
||||
false, // are_inputs_on_device
|
||||
false, // are_outputs_on_device
|
||||
|
||||
@@ -95,6 +95,8 @@ namespace ntt {
|
||||
S coset_gen; /**< Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()`
|
||||
* (corresponding to no coset being used). */
|
||||
int batch_size; /**< The number of NTTs to compute. Default value: 1. */
|
||||
bool columns_batch; /**< True if the batches are the columns of an input matrix
|
||||
(they are strided in memory with a stride of ntt size) Default value: false. */
|
||||
Ordering ordering; /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value:
|
||||
* `Ordering::kNN`. */
|
||||
bool are_inputs_on_device; /**< True if inputs are on device and false if they're on host. Default value: false. */
|
||||
|
||||
@@ -35,6 +35,7 @@ namespace ntt {
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
bool columns_batch,
|
||||
bool is_inverse,
|
||||
bool fast_tw,
|
||||
Ordering ordering,
|
||||
|
||||
@@ -29,6 +29,13 @@ void incremental_values(test_scalar* res, uint32_t count)
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void transpose_batch(test_scalar* in, test_scalar* out, int row_size, int column_size)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= row_size * column_size) return;
out[(tid % row_size) * column_size + (tid / row_size)] = in[tid];
}

int main(int argc, char** argv)
|
||||
{
|
||||
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
|
||||
@@ -37,11 +44,12 @@ int main(int argc, char** argv)
|
||||
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19;
|
||||
int NTT_SIZE = 1 << NTT_LOG_SIZE;
|
||||
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
|
||||
int INV = (argc > 3) ? atoi(argv[3]) : true;
|
||||
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
|
||||
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
|
||||
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
|
||||
bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true;
|
||||
int INV = (argc > 3) ? atoi(argv[3]) : false;
|
||||
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 150;
|
||||
bool COLUMNS_BATCH = (argc > 5) ? atoi(argv[5]) : false;
|
||||
int COSET_IDX = (argc > 6) ? atoi(argv[6]) : 2;
|
||||
const ntt::Ordering ordering = (argc > 7) ? ntt::Ordering(atoi(argv[7])) : ntt::Ordering::kNN;
|
||||
bool FAST_TW = (argc > 8) ? atoi(argv[8]) : true;
|
||||
|
||||
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
|
||||
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
|
||||
@@ -52,8 +60,8 @@ int main(int argc, char** argv)
|
||||
: "MN";
|
||||
|
||||
printf(
|
||||
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE,
|
||||
INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW);
|
||||
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, columns_batch=%d coset-idx=%d, ordering=%s, fast_tw=%d\n",
|
||||
NTT_LOG_SIZE, INPLACE, INV, BATCH_SIZE, COLUMNS_BATCH, COSET_IDX, ordering_str, FAST_TW);
|
||||
|
||||
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
|
||||
|
||||
@@ -63,6 +71,7 @@ int main(int argc, char** argv)
|
||||
ntt_config.are_inputs_on_device = true;
|
||||
ntt_config.are_outputs_on_device = true;
|
||||
ntt_config.batch_size = BATCH_SIZE;
|
||||
ntt_config.columns_batch = COLUMNS_BATCH;
|
||||
|
||||
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
|
||||
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
|
||||
@@ -83,7 +92,9 @@ int main(int argc, char** argv)
|
||||
|
||||
// gpu allocation
|
||||
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
|
||||
test_data* GpuScalarsTransposed;
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuScalarsTransposed, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
|
||||
@@ -93,10 +104,16 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
|
||||
|
||||
if (COLUMNS_BATCH) {
|
||||
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
|
||||
GpuScalars, GpuScalarsTransposed, NTT_SIZE, BATCH_SIZE);
|
||||
}
|
||||
|
||||
// inplace
|
||||
if (INPLACE) {
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
CHK_IF_RETURN(cudaMemcpy(
|
||||
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
|
||||
@@ -109,13 +126,14 @@ int main(int argc, char** argv)
|
||||
ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;
|
||||
for (size_t i = 0; i < iterations; i++) {
|
||||
CHK_IF_RETURN(ntt::NTT(
|
||||
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
|
||||
GpuOutputNew));
|
||||
INPLACE ? GpuOutputNew
|
||||
: COLUMNS_BATCH ? GpuScalarsTransposed
|
||||
: GpuScalars,
|
||||
NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputNew));
|
||||
}
|
||||
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
|
||||
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
|
||||
|
||||
// OLD
|
||||
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
|
||||
@@ -127,7 +145,6 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop));
|
||||
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
|
||||
|
||||
if (is_print) {
|
||||
printf("Old Runtime=%0.3f MS\n", icicle_time / iterations);
|
||||
@@ -140,11 +157,19 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
|
||||
int count = INPLACE ? 1 : 10;
|
||||
if (INPLACE) {
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
CHK_IF_RETURN(cudaMemcpy(
|
||||
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
CHK_IF_RETURN(benchmark(true /*=print*/, count));
|
||||
|
||||
if (COLUMNS_BATCH) {
|
||||
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
|
||||
GpuOutputNew, GpuScalarsTransposed, BATCH_SIZE, NTT_SIZE);
|
||||
CHK_IF_RETURN(cudaMemcpy(
|
||||
GpuOutputNew, GpuScalarsTransposed, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
// verify
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
|
||||
@@ -153,10 +178,11 @@ int main(int argc, char** argv)
|
||||
|
||||
bool success = true;
|
||||
for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) {
|
||||
// if (i%64==0) printf("\n");
|
||||
if (CpuOutputNew[i] != CpuOutputOld[i]) {
|
||||
success = false;
|
||||
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;
|
||||
break;
|
||||
// break;
|
||||
} else {
|
||||
// std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl;
|
||||
// break;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
struct stage_metadata {
|
||||
uint32_t th_stride;
|
||||
uint32_t ntt_block_size;
|
||||
uint32_t batch_id;
|
||||
uint32_t ntt_block_id;
|
||||
uint32_t ntt_inp_id;
|
||||
};
|
||||
@@ -118,7 +119,7 @@ public:
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
|
||||
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
|
||||
{
|
||||
data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));
|
||||
|
||||
@@ -129,7 +130,7 @@ public:
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
|
||||
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
|
||||
{
|
||||
data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));
|
||||
|
||||
@@ -143,7 +144,7 @@ public:
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
|
||||
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
|
||||
{
|
||||
data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));
|
||||
|
||||
@@ -195,8 +196,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadGlobalData(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void
|
||||
loadGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
@@ -211,8 +212,22 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void loadGlobalDataColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; i++) {
|
||||
X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
@@ -227,8 +242,22 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadGlobalData32(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void storeGlobalDataColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; i++) {
|
||||
data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
@@ -246,8 +275,25 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData32(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void loadGlobalData32ColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; i++) {
|
||||
X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
@@ -265,8 +311,25 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadGlobalData16(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void storeGlobalData32ColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; i++) {
|
||||
data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
@@ -284,8 +347,25 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData16(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void loadGlobalData16ColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; i++) {
|
||||
X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
@@ -303,6 +383,23 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData16ColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; i++) {
|
||||
data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void ntt4_2()
|
||||
{
|
||||
#pragma unroll
|
||||
|
||||
@@ -31,6 +31,8 @@ type NTTConfig[T any] struct {
|
||||
CosetGen T
|
||||
/// The number of NTTs to compute. Default value: 1.
|
||||
BatchSize int32
|
||||
/// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
|
||||
ColumnsBatch bool
|
||||
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
|
||||
Ordering Ordering
|
||||
areInputsOnDevice bool
|
||||
@@ -46,6 +48,7 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
|
||||
ctx, // Ctx
|
||||
cosetGen, // CosetGen
|
||||
1, // BatchSize
|
||||
false, // ColumnsBatch
|
||||
KNN, // Ordering
|
||||
false, // areInputsOnDevice
|
||||
false, // areOutputsOnDevice
|
||||
|
||||
@@ -19,6 +19,7 @@ func TestNTTDefaultConfig(t *testing.T) {
|
||||
ctx, // Ctx
|
||||
cosetGen, // CosetGen
|
||||
1, // BatchSize
|
||||
false, // ColumnsBatch
|
||||
KNN, // Ordering
|
||||
false, // areInputsOnDevice
|
||||
false, // areOutputsOnDevice
|
||||
|
||||
@@ -77,6 +77,8 @@ pub struct NTTConfig<'a, S> {
|
||||
pub coset_gen: S,
|
||||
/// The number of NTTs to compute. Default value: 1.
|
||||
pub batch_size: i32,
|
||||
/// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
|
||||
pub columns_batch: bool,
|
||||
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
|
||||
pub ordering: Ordering,
|
||||
are_inputs_on_device: bool,
|
||||
@@ -101,6 +103,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
|
||||
ctx: DeviceContext::default_for_device(device_id),
|
||||
coset_gen: S::one(),
|
||||
batch_size: 1,
|
||||
columns_batch: false,
|
||||
ordering: Ordering::kNN,
|
||||
are_inputs_on_device: false,
|
||||
are_outputs_on_device: false,
|
||||
|
||||
@@ -44,6 +44,14 @@ pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
u32::from_str_radix(&reversed, 2).unwrap()
}

pub fn transpose_flattened_matrix<T: Copy>(m: &[T], nrows: usize) -> Vec<T> {
let ncols = m.len() / nrows;
assert!(nrows * ncols == m.len());
(0..m.len())
.map(|i| m[(i % nrows) * ncols + i / nrows])
.collect()
}

pub fn list_to_reverse_bit_order<T: Copy>(l: &[T]) -> Vec<T> {
|
||||
l.iter()
|
||||
.enumerate()
|
||||
@@ -253,11 +261,10 @@ where
|
||||
] {
|
||||
config.coset_gen = coset_gen;
|
||||
config.ordering = ordering;
|
||||
let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
|
||||
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
|
||||
config.batch_size = batch_size as i32;
|
||||
config.ntt_algorithm = alg;
|
||||
let mut batch_ntt_result =
|
||||
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
|
||||
ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
|
||||
config.batch_size = 1;
|
||||
let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
|
||||
@@ -275,6 +282,20 @@ where
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// for now, columns batching only works with MixedRadix NTT
config.batch_size = batch_size as i32;
config.columns_batch = true;
let transposed_input =
HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size));
let mut col_batch_ntt_result =
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap();
assert_eq!(
batch_ntt_result[..],
transpose_flattened_matrix(&col_batch_ntt_result[..], test_size)
);
config.columns_batch = false;
}
}
}