NTT columns batch (#424)

This PR adds the columns batch feature, enabling batch NTT computation to be
performed directly on the columns of a matrix without transposing it
beforehand, as requested in issue #264.

Also, some small fixes were made to the reordering kernels, and some
unnecessary parameters were removed from function interfaces.
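
For context, a minimal usage sketch of the new mode, modeled on the benchmark added in this PR. The way `ntt_config` is constructed, the field type, and the `GpuMatrix`/`GpuOutput` device pointers are placeholders for whatever the application already uses; only the highlighted config fields and the memory layout are the point.

```cpp
// Columns batch input layout: an NTT_SIZE x BATCH_SIZE matrix stored row-major in
// device memory, i.e. element i of column j lives at index i * BATCH_SIZE + j.
// Each of the BATCH_SIZE columns is transformed as an independent NTT of length NTT_SIZE.
ntt_config.batch_size = BATCH_SIZE;
ntt_config.columns_batch = true; // the new flag added by this PR
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;

// GpuMatrix and GpuOutput each point to NTT_SIZE * BATCH_SIZE field elements on the device.
CHK_IF_RETURN(ntt::NTT(GpuMatrix, NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuOutput));
```

Note that columns batch mode is currently only supported by the mixed-radix algorithm; the radix-2 path is not used when `columns_batch` is set (see the `is_choose_radix2_algorithm` change below).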

---------

Co-authored-by: DmytroTym <dmytrotym1@gmail.com>
Authored by HadarIngonyama on 2024-03-13 18:46:47 +02:00, committed by GitHub
parent 89082fb561
commit 287f53ff16
10 changed files with 385 additions and 127 deletions

View File

@@ -56,7 +56,15 @@ namespace ntt {
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void reorder_digits_inplace_and_normalize_kernel(
E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N)
E* arr,
uint32_t log_size,
bool columns_batch,
uint32_t batch_size,
bool dit,
bool fast_tw,
eRevType rev_type,
bool is_normalize,
S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -65,19 +73,20 @@ namespace ntt {
const uint32_t size = 1 << log_size;
const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t idx = tid % size;
const uint32_t batch_idx = tid / size;
const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
if (tid >= size * batch_size) return;
uint32_t next_element = idx;
uint32_t group[MAX_GROUP_SIZE];
group[0] = next_element + size * batch_idx;
group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element + size * batch_idx;
group[i++] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;
}
--i;
@@ -94,6 +103,9 @@ namespace ntt {
E* arr,
E* arr_reordered,
uint32_t log_size,
bool columns_batch,
uint32_t batch_size,
uint32_t columns_batch_size,
bool dit,
bool fast_tw,
eRevType rev_type,
@@ -101,29 +113,33 @@ namespace ntt {
S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= (1 << log_size) * batch_size) return;
uint32_t rd = tid;
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
template <typename E, typename S>
static __global__ void batch_elementwise_mul_with_reorder(
static __global__ void batch_elementwise_mul_with_reorder_kernel(
E* in_vec,
int n_elements,
int batch_size,
uint32_t size,
bool columns_batch,
uint32_t batch_size,
uint32_t columns_batch_size,
S* scalar_vec,
int step,
int n_scalars,
int logn,
uint32_t log_size,
eRevType rev_type,
bool dit,
E* out_vec)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= n_elements * batch_size) return;
int64_t scalar_id = tid % n_elements;
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type);
if (tid >= size * batch_size) return;
int64_t scalar_id = (tid / columns_batch_size) % size;
if (rev_type != eRevType::None)
scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type);
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
@@ -136,6 +152,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -153,19 +170,27 @@ namespace ntt {
s_meta.th_stride = 8;
s_meta.ntt_block_size = 64;
s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 7) / 8)
: (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0x7) + ((blockIdx.x % ((columns_batch_size + 7) / 8)) << 3) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (twiddle_stride && dit) {
if (fast_tw)
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
@@ -189,13 +214,16 @@ namespace ntt {
if (twiddle_stride && !dit) {
if (fast_tw)
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
@@ -207,6 +235,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -225,16 +254,25 @@ namespace ntt {
s_meta.th_stride = 4;
s_meta.ntt_block_size = 32;
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (fast_tw)
engine.loadInternalTwiddles32(internal_twiddles, strided);
else
@@ -247,13 +285,16 @@ namespace ntt {
engine.ntt4_2();
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalData32ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData32(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
@@ -265,6 +306,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -283,19 +325,27 @@ namespace ntt {
s_meta.th_stride = 4;
s_meta.ntt_block_size = 32;
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalData32ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData32(in, data_stride, log_data_stride, strided, s_meta);
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
@@ -311,7 +361,10 @@ namespace ntt {
engine.SharedData32Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
engine.ntt8win();
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
@@ -323,6 +376,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -341,16 +395,26 @@ namespace ntt {
s_meta.th_stride = 2;
s_meta.ntt_block_size = 16;
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_block_id = columns_batch_size
? blockIdx.x / ((columns_batch_size + 31) / 32)
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (fast_tw)
engine.loadInternalTwiddles16(internal_twiddles, strided);
else
@@ -363,13 +427,16 @@ namespace ntt {
engine.ntt2_4();
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalData16ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData16(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
@@ -381,6 +448,7 @@ namespace ntt {
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -399,19 +467,29 @@ namespace ntt {
s_meta.th_stride = 2;
s_meta.ntt_block_size = 16;
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_block_id = columns_batch_size
? blockIdx.x / ((columns_batch_size + 31) / 32)
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalData16ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData16(in, data_stride, log_data_stride, strided, s_meta);
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
@@ -427,13 +505,17 @@ namespace ntt {
engine.SharedData16Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
engine.ntt8win();
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
__global__ void normalize_kernel(E* data, S norm_factor)
__global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= size) return;
data[tid] = data[tid] * norm_factor;
}
@@ -666,6 +748,7 @@ namespace ntt {
uint32_t log_size,
uint32_t tw_log_size,
uint32_t batch_size,
bool columns_batch,
bool inv,
bool normalize,
bool dit,
@@ -679,72 +762,83 @@ namespace ntt {
}
if (log_size == 4) {
const int NOF_THREADS = min(64, 2 * batch_size);
const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_THREADS = columns_batch ? 64 : min(64, 2 * batch_size);
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 31) / 32) : (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
if (dit) {
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
} else { // dif
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
if (normalize)
normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
return CHK_LAST();
}
if (log_size == 5) {
const int NOF_THREADS = min(64, 4 * batch_size);
const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_THREADS = columns_batch ? 64 : min(64, 4 * batch_size);
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 15) / 16) : (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
if (dit) {
ntt32dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
} else { // dif
ntt32<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
if (normalize)
normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
return CHK_LAST();
}
if (log_size == 6) {
const int NOF_THREADS = min(64, 8 * batch_size);
const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_THREADS = columns_batch ? 64 : min(64, 8 * batch_size);
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 7) / 8) : ((8 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
if (normalize)
normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
return CHK_LAST();
}
if (log_size == 8) {
const int NOF_THREADS = 64;
const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 31) / 32 * 16) : ((32 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
if (dit) {
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
columns_batch, 0, inv, dit, fast_tw);
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
inv, dit, fast_tw);
} else { // dif
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
inv, dit, fast_tw);
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
columns_batch, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
if (normalize)
normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
return CHK_LAST();
}
// general case:
uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
@@ -754,18 +848,18 @@ namespace ntt {
if (stage_size == 6)
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 5)
ntt32dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 4)
ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
}
} else { // dif
bool first_run = false, prev_stage = false;
@@ -778,23 +872,24 @@ namespace ntt {
if (stage_size == 6)
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 5)
ntt32<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 4)
ntt16<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
prev_stage = stage_size;
}
}
if (normalize)
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
return CHK_LAST();
}
@@ -809,6 +904,7 @@ namespace ntt {
int ntt_size,
int max_logn,
int batch_size,
bool columns_batch,
bool is_inverse,
bool fast_tw,
Ordering ordering,
@@ -858,9 +954,10 @@ namespace ntt {
}
if (is_on_coset && !is_inverse) {
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
reverse_coset, dit, d_output);
d_input = d_output;
}
@@ -869,10 +966,11 @@ namespace ntt {
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
d_input, d_output, logn, columns_batch, batch_size, columns_batch ? batch_size : 1, dit, fast_tw,
reverse_input, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
d_input = d_output;
@@ -880,18 +978,19 @@ namespace ntt {
// inplace ntt
CHK_IF_RETURN(large_ntt(
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
(is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size,
columns_batch, is_inverse, (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
if (reverse_output != eRevType::None) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
}
if (is_on_coset && is_inverse) {
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
n_twiddles, logn, reverse_coset, dit, d_output);
}
return CHK_LAST();
@@ -923,6 +1022,7 @@ namespace ntt {
int ntt_size,
int max_logn,
int batch_size,
bool columns_batch,
bool is_inverse,
bool fast_tw,
Ordering ordering,

View File

@@ -516,12 +516,15 @@ namespace ntt {
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
{
const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
if (!is_mixed_radix_alg_supported && config.columns_batch)
throw IcicleError(IcicleError_t::InvalidArgument, "columns batch is not supported for given NTT size");
const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2;
const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg;
if (is_force_radix2) return true;
const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix;
if (is_user_selected_mixed_radix_alg) return false;
if (config.columns_batch) return false; // radix2 does not currently support columns batch mode.
// Heuristic to automatically select an algorithm
// Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and
@@ -663,7 +666,7 @@ namespace ntt {
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
}
if (!are_outputs_on_device)
@@ -685,6 +688,7 @@ namespace ntt {
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
false, // columns_batch
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device

View File

@@ -95,6 +95,8 @@ namespace ntt {
S coset_gen; /**< Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()`
* (corresponding to no coset being used). */
int batch_size; /**< The number of NTTs to compute. Default value: 1. */
bool columns_batch; /**< True if the batches are the columns of an input matrix
(they are strided in memory with a stride of ntt size) Default value: false. */
Ordering ordering; /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value:
* `Ordering::kNN`. */
bool are_inputs_on_device; /**< True if inputs are on device and false if they're on host. Default value: false. */

View File

@@ -35,6 +35,7 @@ namespace ntt {
int ntt_size,
int max_logn,
int batch_size,
bool columns_batch,
bool is_inverse,
bool fast_tw,
Ordering ordering,

View File

@@ -29,6 +29,13 @@ void incremental_values(test_scalar* res, uint32_t count)
}
}
__global__ void transpose_batch(test_scalar* in, test_scalar* out, int row_size, int column_size)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= row_size * column_size) return;
out[(tid % row_size) * column_size + (tid / row_size)] = in[tid];
}
int main(int argc, char** argv)
{
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
@@ -37,11 +44,12 @@ int main(int argc, char** argv)
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19;
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : true;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true;
int INV = (argc > 3) ? atoi(argv[3]) : false;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 150;
bool COLUMNS_BATCH = (argc > 5) ? atoi(argv[5]) : false;
int COSET_IDX = (argc > 6) ? atoi(argv[6]) : 2;
const ntt::Ordering ordering = (argc > 7) ? ntt::Ordering(atoi(argv[7])) : ntt::Ordering::kNN;
bool FAST_TW = (argc > 8) ? atoi(argv[8]) : true;
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
@@ -52,8 +60,8 @@ int main(int argc, char** argv)
: "MN";
printf(
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE,
INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW);
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, columns_batch=%d coset-idx=%d, ordering=%s, fast_tw=%d\n",
NTT_LOG_SIZE, INPLACE, INV, BATCH_SIZE, COLUMNS_BATCH, COSET_IDX, ordering_str, FAST_TW);
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
@@ -63,6 +71,7 @@ int main(int argc, char** argv)
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
ntt_config.batch_size = BATCH_SIZE;
ntt_config.columns_batch = COLUMNS_BATCH;
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
@@ -83,7 +92,9 @@ int main(int argc, char** argv)
// gpu allocation
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
test_data* GpuScalarsTransposed;
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuScalarsTransposed, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
@@ -93,10 +104,16 @@ int main(int argc, char** argv)
CHK_IF_RETURN(
cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
if (COLUMNS_BATCH) {
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
GpuScalars, GpuScalarsTransposed, NTT_SIZE, BATCH_SIZE);
}
// inplace
if (INPLACE) {
CHK_IF_RETURN(
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
CHK_IF_RETURN(cudaMemcpy(
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
cudaMemcpyDeviceToDevice));
}
for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
@@ -109,13 +126,14 @@ int main(int argc, char** argv)
ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;
for (size_t i = 0; i < iterations; i++) {
CHK_IF_RETURN(ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
GpuOutputNew));
INPLACE ? GpuOutputNew
: COLUMNS_BATCH ? GpuScalarsTransposed
: GpuScalars,
NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputNew));
}
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
// OLD
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
@@ -127,7 +145,6 @@ int main(int argc, char** argv)
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop));
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
if (is_print) {
printf("Old Runtime=%0.3f MS\n", icicle_time / iterations);
@@ -140,11 +157,19 @@ int main(int argc, char** argv)
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
int count = INPLACE ? 1 : 10;
if (INPLACE) {
CHK_IF_RETURN(
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
CHK_IF_RETURN(cudaMemcpy(
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
cudaMemcpyDeviceToDevice));
}
CHK_IF_RETURN(benchmark(true /*=print*/, count));
if (COLUMNS_BATCH) {
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
GpuOutputNew, GpuScalarsTransposed, BATCH_SIZE, NTT_SIZE);
CHK_IF_RETURN(cudaMemcpy(
GpuOutputNew, GpuScalarsTransposed, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
}
// verify
CHK_IF_RETURN(
cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
@@ -153,10 +178,11 @@ int main(int argc, char** argv)
bool success = true;
for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) {
// if (i%64==0) printf("\n");
if (CpuOutputNew[i] != CpuOutputOld[i]) {
success = false;
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;
break;
// break;
} else {
// std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl;
// break;

View File

@@ -9,6 +9,7 @@
struct stage_metadata {
uint32_t th_stride;
uint32_t ntt_block_size;
uint32_t batch_id;
uint32_t ntt_block_id;
uint32_t ntt_inp_id;
};
@@ -118,7 +119,7 @@ public:
}
__device__ __forceinline__ void
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));
@@ -129,7 +130,7 @@ public:
}
__device__ __forceinline__ void
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));
@@ -143,7 +144,7 @@ public:
}
__device__ __forceinline__ void
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));
@@ -195,8 +196,8 @@ public:
}
}
__device__ __forceinline__ void loadGlobalData(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void
loadGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
@@ -211,8 +212,22 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void loadGlobalDataColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
}
}
__device__ __forceinline__ void
storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
@@ -227,8 +242,22 @@ public:
}
}
__device__ __forceinline__ void loadGlobalData32(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void storeGlobalDataColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
}
}
__device__ __forceinline__ void
loadGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
@@ -246,8 +275,25 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData32(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void loadGlobalData32ColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
}
}
}
__device__ __forceinline__ void
storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
@@ -265,8 +311,25 @@ public:
}
}
__device__ __forceinline__ void loadGlobalData16(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void storeGlobalData32ColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
}
}
}
__device__ __forceinline__ void
loadGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
@@ -284,8 +347,25 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData16(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void loadGlobalData16ColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
}
}
}
__device__ __forceinline__ void
storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
@@ -303,6 +383,23 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData16ColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
}
}
}
__device__ __forceinline__ void ntt4_2()
{
#pragma unroll

View File

@@ -31,6 +31,8 @@ type NTTConfig[T any] struct {
CosetGen T
/// The number of NTTs to compute. Default value: 1.
BatchSize int32
/// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
ColumnsBatch bool
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
Ordering Ordering
areInputsOnDevice bool
@@ -46,6 +48,7 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
ctx, // Ctx
cosetGen, // CosetGen
1, // BatchSize
false, // ColumnsBatch
KNN, // Ordering
false, // areInputsOnDevice
false, // areOutputsOnDevice

View File

@@ -19,6 +19,7 @@ func TestNTTDefaultConfig(t *testing.T) {
ctx, // Ctx
cosetGen, // CosetGen
1, // BatchSize
false, // ColumnsBatch
KNN, // Ordering
false, // areInputsOnDevice
false, // areOutputsOnDevice

View File

@@ -77,6 +77,8 @@ pub struct NTTConfig<'a, S> {
pub coset_gen: S,
/// The number of NTTs to compute. Default value: 1.
pub batch_size: i32,
/// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
pub columns_batch: bool,
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
pub ordering: Ordering,
are_inputs_on_device: bool,
@@ -101,6 +103,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
ctx: DeviceContext::default_for_device(device_id),
coset_gen: S::one(),
batch_size: 1,
columns_batch: false,
ordering: Ordering::kNN,
are_inputs_on_device: false,
are_outputs_on_device: false,

View File

@@ -44,6 +44,14 @@ pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
u32::from_str_radix(&reversed, 2).unwrap()
}
pub fn transpose_flattened_matrix<T: Copy>(m: &[T], nrows: usize) -> Vec<T> {
let ncols = m.len() / nrows;
assert!(nrows * ncols == m.len());
(0..m.len())
.map(|i| m[(i % nrows) * ncols + i / nrows])
.collect()
}
pub fn list_to_reverse_bit_order<T: Copy>(l: &[T]) -> Vec<T> {
l.iter()
.enumerate()
@@ -253,11 +261,10 @@ where
] {
config.coset_gen = coset_gen;
config.ordering = ordering;
let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
config.batch_size = batch_size as i32;
config.ntt_algorithm = alg;
let mut batch_ntt_result =
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
config.batch_size = 1;
let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
@@ -275,6 +282,20 @@ where
);
}
}
// for now, columns batching only works with MixedRadix NTT
config.batch_size = batch_size as i32;
config.columns_batch = true;
let transposed_input =
HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size));
let mut col_batch_ntt_result =
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap();
assert_eq!(
batch_ntt_result[..],
transpose_flattened_matrix(&col_batch_ntt_result[..], test_size)
);
config.columns_batch = false;
}
}
}