feat(cuda): add a new fft algorithm.

- FFT works for any polynomial size, as long as the twiddles are provided.
- All the twiddles fit in constant memory.
- Bit reversal is no longer used; the sw1 and sw2 arrays are removed from constant memory.
- The real-to-complex compression algorithm is changed.
- The twiddle initialization functions are removed.
This commit is contained in:
Beka Barbakadze
2023-01-27 14:20:07 +04:00
committed by bbarbakadze
parent bd9cbbc7af
commit 3cd48f0de2
11 changed files with 4150 additions and 8905 deletions

View File

@@ -5,9 +5,6 @@
extern "C" {
void cuda_initialize_twiddles(uint32_t polynomial_size, void *v_stream,
uint32_t gpu_index);
void cuda_convert_lwe_bootstrap_key_32(void *dest, void *src, void *v_stream,
uint32_t gpu_index,
uint32_t input_lwe_dim,

View File

@@ -165,8 +165,6 @@ __global__ void device_bootstrap_amortized(
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();
correction_direct_fft_inplace<params>(accumulator_fft);
// Get the bootstrapping key piece necessary for the multiplication
// It is already in the Fourier domain
auto bsk_mask_slice =
@@ -194,8 +192,6 @@ __global__ void device_bootstrap_amortized(
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();
correction_direct_fft_inplace<params>(accumulator_fft);
auto bsk_mask_slice_2 =
get_ith_mask_kth_block(bootstrapping_key, iteration, 1, level,
polynomial_size, 1, level_count);
@@ -215,10 +211,6 @@ __global__ void device_bootstrap_amortized(
if constexpr (SMD == FULLSM || SMD == NOSM) {
synchronize_threads_in_block();
correction_inverse_fft_inplace<params>(mask_res_fft);
correction_inverse_fft_inplace<params>(body_res_fft);
synchronize_threads_in_block();
NSMFFT_inverse<HalfDegree<params>>(mask_res_fft);
NSMFFT_inverse<HalfDegree<params>>(body_res_fft);
@@ -227,6 +219,7 @@ __global__ void device_bootstrap_amortized(
add_to_torus<Torus, params>(mask_res_fft, accumulator_mask);
add_to_torus<Torus, params>(body_res_fft, accumulator_body);
synchronize_threads_in_block();
} else {
int tid = threadIdx.x;
#pragma unroll
@@ -236,9 +229,6 @@ __global__ void device_bootstrap_amortized(
}
synchronize_threads_in_block();
correction_inverse_fft_inplace<params>(accumulator_fft);
synchronize_threads_in_block();
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();
@@ -253,9 +243,6 @@ __global__ void device_bootstrap_amortized(
}
synchronize_threads_in_block();
correction_inverse_fft_inplace<params>(accumulator_fft);
synchronize_threads_in_block();
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();

View File

@@ -38,9 +38,6 @@ mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
NSMFFT_direct<HalfDegree<params>>(fft);
synchronize_threads_in_block();
correction_direct_fft_inplace<params>(fft);
synchronize_threads_in_block();
// Get the pieces of the bootstrapping key that will be needed for the
// external product; blockIdx.x is the ID of the block that's executing
// this function, so we end up getting the lines of the bootstrapping key
@@ -113,9 +110,6 @@ mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
synchronize_threads_in_block();
correction_inverse_fft_inplace<params>(fft);
synchronize_threads_in_block();
// Perform the inverse FFT on the result of the GGSW x GLWE and add to the
// accumulator
NSMFFT_inverse<HalfDegree<params>>(fft);

View File

@@ -38,39 +38,6 @@ __device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
polynomial_size / 2];
}
/*
 * Precompute the bit-reversal swap tables for a half-size (N/2) FFT and
 * upload them to the constant-memory symbols SW1 and SW2 on the given GPU.
 *
 * For every index pair (i, j) with j = bitreverse(i) and i < j, the pair is
 * appended to (sw1_h, sw2_h); entries past `cnt` stay zero from the memset.
 * The upload is issued asynchronously on *v_stream.
 */
void cuda_initialize_twiddles(uint32_t polynomial_size, void *v_stream,
uint32_t gpu_index) {
cudaSetDevice(gpu_index);
// Only N/2 entries are needed: the transform operates on N/2 complex points.
int sw_size = polynomial_size / 2;
short *sw1_h, *sw2_h;
// NOTE(review): malloc results are not checked; a failed allocation would
// crash in the memset below. Also no CUDA API return codes are checked.
sw1_h = (short *)malloc(sizeof(short) * sw_size);
sw2_h = (short *)malloc(sizeof(short) * sw_size);
memset(sw1_h, 0, sw_size * sizeof(short));
memset(sw2_h, 0, sw_size * sizeof(short));
int cnt = 0;
// Classic iterative bit-reversal enumeration: j tracks bitreverse(i) as i
// increments, by propagating a "carry" downward from the top bit.
for (int i = 1, j = 0; i < polynomial_size / 2; i++) {
int bit = (polynomial_size / 2) >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
// Record each unordered pair exactly once (i < j) so that swapping
// A[sw1[k]] with A[sw2[k]] performs the permutation in place.
if (i < j) {
sw1_h[cnt] = i;
sw2_h[cnt] = j;
cnt++;
}
}
auto stream = static_cast<cudaStream_t *>(v_stream);
// NOTE(review): the host buffers are freed immediately after the async
// copies — safe only because the source is pageable memory, which the
// runtime stages before returning from the async call; revisit if these
// buffers are ever switched to pinned allocations.
cudaMemcpyToSymbolAsync(SW1, sw1_h, sw_size * sizeof(short), 0,
cudaMemcpyHostToDevice, *stream);
cudaMemcpyToSymbolAsync(SW2, sw2_h, sw_size * sizeof(short), 0,
cudaMemcpyHostToDevice, *stream);
free(sw1_h);
free(sw2_h);
}
template <typename T, typename ST>
void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
uint32_t gpu_index, uint32_t input_lwe_dim,
@@ -101,10 +68,9 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
int complex_current_poly_idx = i * polynomial_size / 2;
int torus_current_poly_idx = i * polynomial_size;
for (int j = 0; j < polynomial_size / 2; j++) {
h_bsk[complex_current_poly_idx + j].x =
src[torus_current_poly_idx + 2 * j];
h_bsk[complex_current_poly_idx + j].x = src[torus_current_poly_idx + j];
h_bsk[complex_current_poly_idx + j].y =
src[torus_current_poly_idx + 2 * j + 1];
src[torus_current_poly_idx + j + polynomial_size / 2];
h_bsk[complex_current_poly_idx + j].x /=
(double)std::numeric_limits<T>::max();
h_bsk[complex_current_poly_idx + j].y /=

View File

@@ -34,16 +34,16 @@ public:
int tid = threadIdx.x;
current_level -= 1;
for (int i = 0; i < params::opt / 2; i++) {
T res_re = state[tid * 2] & mask_mod_b;
T res_im = state[tid * 2 + 1] & mask_mod_b;
state[tid * 2] >>= base_log;
state[tid * 2 + 1] >>= base_log;
T carry_re = ((res_re - 1ll) | state[tid * 2]) & res_re;
T carry_im = ((res_im - 1ll) | state[tid * 2 + 1]) & res_im;
T res_re = state[tid] & mask_mod_b;
T res_im = state[tid + params::degree / 2] & mask_mod_b;
state[tid] >>= base_log;
state[tid + params::degree / 2] >>= base_log;
T carry_re = ((res_re - 1ll) | state[tid]) & res_re;
T carry_im = ((res_im - 1ll) | state[tid + params::degree / 2]) & res_im;
carry_re >>= (base_log - 1);
carry_im >>= (base_log - 1);
state[tid * 2] += carry_re;
state[tid * 2 + 1] += carry_im;
state[tid] += carry_re;
state[tid + params::degree / 2] += carry_im;
res_re -= carry_re << base_log;
res_im -= carry_im << base_log;

View File

@@ -23,8 +23,8 @@ __global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
#pragma unroll
for (int i = 0; i < log_2_opt; i++) {
ST x = src[(2 * tid) + params::opt * offset];
ST y = src[(2 * tid + 1) + params::opt * offset];
ST x = src[(tid) + params::opt * offset];
ST y = src[(tid + params::degree / 2) + params::opt * offset];
selected_memory[tid].x = x / (double)std::numeric_limits<T>::max();
selected_memory[tid].y = y / (double)std::numeric_limits<T>::max();
tid += params::degree / params::opt;
@@ -35,9 +35,6 @@ __global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
NSMFFT_direct<HalfDegree<params>>(selected_memory);
synchronize_threads_in_block();
correction_direct_fft_inplace<params>(selected_memory);
synchronize_threads_in_block();
// Write the output to global memory
tid = threadIdx.x;
#pragma unroll

View File

@@ -6,55 +6,6 @@
#include "polynomial/parameters.cuh"
#include "twiddles.cuh"
/*
 * In-place bit-reversal permutation of the coefficient array.
 * The index pairs to exchange were precomputed on the host and stored in the
 * constant-memory arrays SW1 and SW2; each thread walks its opt/2 table
 * entries with a stride of degree/opt and swaps the two coefficients.
 */
template <class params> __device__ void bit_reverse_inplace(double2 *A) {
  constexpr int stride = params::degree / params::opt;
  int idx = threadIdx.x;
#pragma unroll
  for (int it = 0; it < params::opt / 2; ++it) {
    short lo = SW1[idx];
    short hi = SW2[idx];
    double2 saved = A[lo];
    A[lo] = A[hi];
    A[hi] = saved;
    idx += stride;
  }
}
/*
 * Negacyclic twiddle lookup.
 * Selects, at compile time, the precalculated twiddle table matching the
 * template degree (negTwids9..negTwids13 for degrees 512..8192) and returns
 * entry j. Unsupported degrees yield a zero twiddle, exactly as before.
 */
template <int degree> __device__ double2 negacyclic_twiddle(int j) {
  double2 w;
  if constexpr (degree == 512) {
    w = negTwids9[j];
  } else if constexpr (degree == 1024) {
    w = negTwids10[j];
  } else if constexpr (degree == 2048) {
    w = negTwids11[j];
  } else if constexpr (degree == 4096) {
    w = negTwids12[j];
  } else if constexpr (degree == 8192) {
    w = negTwids13[j];
  } else {
    w.x = 0;
    w.y = 0;
  }
  return w;
}
/*
* Direct negacyclic FFT:
* - before the FFT the N real coefficients are stored into a
@@ -75,641 +26,120 @@ template <int degree> __device__ double2 negacyclic_twiddle(int j) {
* forward_negacyclic_fft_inplace function of bootstrap.cuh
*/
template <class params> __device__ void NSMFFT_direct(double2 *A) {
/* First, reverse the bits of the input complex
* The bit reversal for half-size FFT has been stored into the
* SW1 and SW2 arrays beforehand
/* We don't make bit reverse here, since twiddles are already reversed
* Each thread is always in charge of "opt/2" pairs of coefficients,
* which is why we always loop through N/2 by N/opt strides
* The pragma unroll instruction tells the compiler to unroll the
* full loop, which should increase performance
*/
bit_reverse_inplace<params>(A);
__syncthreads();
// Now we go through all the levels of the FFT one by one
// (instead of recursively)
// first iteration: k=1, zeta=i for all coefficients
int tid = threadIdx.x;
int i1, i2;
double2 u, v;
size_t tid = threadIdx.x;
size_t twid_id;
size_t t = params::degree / 2;
size_t m = 1;
size_t i1, i2;
double2 u, v, w;
// level 1
// we don't make actual complex multiplication on level1 since we have only
// one twiddle, it's real and image parts are equal, so we can multiply
// it with simpler operations
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 2
i1 = tid << 1;
i2 = i1 + 1;
for (size_t i = 0; i < params::opt / 2; ++i) {
i1 = tid;
i2 = tid + t;
u = A[i1];
// v = i*A[i2]
v.y = A[i2].x;
v.x = -A[i2].y;
// A[i1] <- A[i1] + i*A[i2]
// A[i2] <- A[i1] - i*A[i2]
A[i1] += v;
A[i2] = u - v;
v.x = (A[i2].x - A[i2].y) * 0.707106781186547461715008466854;
v.y = (A[i2].x + A[i2].y) * 0.707106781186547461715008466854;
A[i1].x += v.x;
A[i1].y += v.y;
A[i2].x = u.x - v.x;
A[i2].y = u.y - v.y;
tid += params::degree / params::opt;
}
__syncthreads();
// second iteration: apply the butterfly pattern
// between groups of 4 coefficients
// k=2, \zeta=exp(i pi/4) for even coefficients and
// exp(3 i pi / 4) for odd coefficients
tid = threadIdx.x;
// odd = 0 for even coefficients, 1 for odd coefficients
int odd = tid & 1;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 2
// i1=2*tid if tid is even and 2*tid-1 if it is odd
i1 = (tid << 1) - odd;
i2 = i1 + 2;
double a = A[i2].x;
double b = A[i2].y;
u = A[i1];
// \zeta_j,2 = exp(-i pi (2j-1)/4) -> j=0: exp(i pi/4) or j=1: exp(-i pi/4)
// \zeta_even = sqrt(2)/2 + i * sqrt(2)/2 = sqrt(2)/2*(1+i)
// \zeta_odd = sqrt(2)/2 - i * sqrt(2)/2 = sqrt(2)/2*(1-i)
// v_j = \zeta_j * (a+i*b)
// v_even = sqrt(2)/2*((a-b)+i*(a+b))
// v_odd = sqrt(2)/2*(a+b+i*(b-a))
v.x =
(odd) ? (-0.707106781186548) * (a + b) : (0.707106781186548) * (a - b);
v.y = (odd) ? (0.707106781186548) * (a - b) : (0.707106781186548) * (a + b);
// v.x = (0.707106781186548 * odd) * (a + b) + (0.707106781186548 * (!odd))
// * (a - b); v.y = (0.707106781186548 * odd) * (b - a) + (0.707106781186548
// * (!odd)) * (a + b);
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
// third iteration
// from k=3 on, we have to do the full complex multiplication
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 4
// rem is the remainder of tid/4. tid takes values:
// 0, 1, 2, 3, 4, 5, 6, 7, ... N/4
// then rem takes values:
// 0, 1, 2, 3, 0, 1, 2, 3, ... N/4
// and striding by 4 will allow us to cover all
// the coefficients correctly
int rem = tid & 3;
i1 = (tid << 1) - rem;
i2 = i1 + 4;
double2 w = negTwids3[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
// 4_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 8
// rem is the remainder of tid/8. tid takes values:
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ... N/4
// then rem takes values:
// 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, ... N/4
// and striding by 8 will allow us to cover all
// the coefficients correctly
int rem = tid & 7;
i1 = (tid << 1) - rem;
i2 = i1 + 8;
double2 w = negTwids4[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
// 5_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 16
// rem is the remainder of tid/16
// and the same logic as for previous iterations applies
int rem = tid & 15;
i1 = (tid << 1) - rem;
i2 = i1 + 16;
double2 w = negTwids5[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
// 6_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 32
// rem is the remainder of tid/32
// and the same logic as for previous iterations applies
int rem = tid & 31;
i1 = (tid << 1) - rem;
i2 = i1 + 32;
double2 w = negTwids6[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
// 7_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 64
// rem is the remainder of tid/64
// and the same logic as for previous iterations applies
int rem = tid & 63;
i1 = (tid << 1) - rem;
i2 = i1 + 64;
double2 w = negTwids7[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
// 8_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 128
// rem is the remainder of tid/128
// and the same logic as for previous iterations applies
int rem = tid & 127;
i1 = (tid << 1) - rem;
i2 = i1 + 128;
double2 w = negTwids8[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
if constexpr (params::log2_degree > 8) {
// 9_th iteration
size_t iter = 1;
// for levels more than 1
// from here none of the twiddles have equal real and imag part, so
// complete complex multiplication has to be done
// here we have more than one twiddles
while (t > 1) {
iter++;
tid = threadIdx.x;
t >>= 1;
m <<= 1;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 256
// rem is the remainder of tid/256
// and the same logic as for previous iterations applies
int rem = tid & 255;
i1 = (tid << 1) - rem;
i2 = i1 + 256;
double2 w = negTwids9[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / t;
i1 = 2 * t * twid_id + (tid & (t - 1));
i2 = i1 + t;
w = negtwiddles[twid_id + m];
u = A[i1];
v.x = A[i2].x * w.x - A[i2].y * w.y;
v.y = A[i2].y * w.x + A[i2].x * w.y;
A[i1].x += v.x;
A[i1].y += v.y;
A[i2].x = u.x - v.x;
A[i2].y = u.y - v.y;
tid += params::degree / params::opt;
}
__syncthreads();
}
if constexpr (params::log2_degree > 9) {
// 10_th iteration
tid = threadIdx.x;
//#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 512
// rem is the remainder of tid/512
// and the same logic as for previous iterations applies
int rem = tid & 511;
i1 = (tid << 1) - rem;
i2 = i1 + 512;
double2 w = negTwids10[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
}
if constexpr (params::log2_degree > 10) {
// 11_th iteration
tid = threadIdx.x;
//#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 1024
// rem is the remainder of tid/1024
// and the same logic as for previous iterations applies
int rem = tid & 1023;
i1 = (tid << 1) - rem;
i2 = i1 + 1024;
double2 w = negTwids11[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
}
if constexpr (params::log2_degree > 11) {
// 12_th iteration
tid = threadIdx.x;
//#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 2048
// rem is the remainder of tid/2048
// and the same logic as for previous iterations applies
int rem = tid & 2047;
i1 = (tid << 1) - rem;
i2 = i1 + 2048;
double2 w = negTwids12[rem];
u = A[i1], v = A[i2] * w;
A[i1] = u + v;
A[i2] = u - v;
tid += params::degree / params::opt;
}
__syncthreads();
}
// Real polynomials handled should not exceed a degree of 8192
}
/*
* negacyclic inverse fft
*/
template <class params> __device__ void NSMFFT_inverse(double2 *A) {
/* First, reverse the bits of the input complex
* The bit reversal for half-size FFT has been stored into the
* SW1 and SW2 arrays beforehand
/* We don't make bit reverse here, since twiddles are already reversed
* Each thread is always in charge of "opt/2" pairs of coefficients,
* which is why we always loop through N/2 by N/opt strides
* The pragma unroll instruction tells the compiler to unroll the
* full loop, which should increase performance
*/
int tid;
int i1, i2;
double2 u, v;
if constexpr (params::log2_degree > 11) {
// 12_th iteration
size_t tid = threadIdx.x;
size_t twid_id;
size_t m = params::degree;
size_t t = 1;
size_t i1, i2;
double2 u, w;
tid = threadIdx.x;
for (size_t i = 0; i < params::opt; ++i) {
A[tid].x *= 1. / params::degree;
A[tid].y *= 1. / params::degree;
tid += params::degree / params::opt;
}
__syncthreads();
// none of the twiddles have equal real and imag part, so
// complete complex multiplication has to be done
// here we have more than one twiddles
while (m > 1) {
tid = threadIdx.x;
//#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 2048
// rem is the remainder of tid/2048
// and the same logic as for previous iterations applies
int rem = tid & 2047;
i1 = (tid << 1) - rem;
i2 = i1 + 2048;
double2 w = conjugate(negTwids12[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
m >>= 1;
#pragma unroll
for (size_t i = 0; i < params::opt / 2; ++i) {
twid_id = tid / t;
i1 = 2 * t * twid_id + (tid & (t - 1));
i2 = i1 + t;
w = negtwiddles[twid_id + m];
u.x = A[i1].x - A[i2].x;
u.y = A[i1].y - A[i2].y;
A[i1].x += A[i2].x;
A[i1].y += A[i2].y;
A[i2].x = u.x * w.x + u.y * w.y;
A[i2].y = u.y * w.x - u.x * w.y;
tid += params::degree / params::opt;
}
t <<= 1;
__syncthreads();
}
if constexpr (params::log2_degree > 10) {
// 11_th iteration
tid = threadIdx.x;
//#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 1024
// rem is the remainder of tid/1024
// and the same logic as for previous iterations applies
int rem = tid & 1023;
i1 = (tid << 1) - rem;
i2 = i1 + 1024;
double2 w = conjugate(negTwids11[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
}
if constexpr (params::log2_degree > 9) {
// 10_th iteration
tid = threadIdx.x;
//#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 512
// rem is the remainder of tid/512
// and the same logic as for previous iterations applies
int rem = tid & 511;
i1 = (tid << 1) - rem;
i2 = i1 + 512;
double2 w = conjugate(negTwids10[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
}
if constexpr (params::log2_degree > 8) {
// 9_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 256
// rem is the remainder of tid/256
// and the same logic as for previous iterations applies
int rem = tid & 255;
i1 = (tid << 1) - rem;
i2 = i1 + 256;
double2 w = conjugate(negTwids9[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
}
// 8_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 128
// rem is the remainder of tid/128
// and the same logic as for previous iterations applies
int rem = tid & 127;
i1 = (tid << 1) - rem;
i2 = i1 + 128;
double2 w = conjugate(negTwids8[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// 7_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 64
// rem is the remainder of tid/64
// and the same logic as for previous iterations applies
int rem = tid & 63;
i1 = (tid << 1) - rem;
i2 = i1 + 64;
double2 w = conjugate(negTwids7[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// 6_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 32
// rem is the remainder of tid/32
// and the same logic as for previous iterations applies
int rem = tid & 31;
i1 = (tid << 1) - rem;
i2 = i1 + 32;
double2 w = conjugate(negTwids6[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// 5_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 16
// rem is the remainder of tid/16
// and the same logic as for previous iterations applies
int rem = tid & 15;
i1 = (tid << 1) - rem;
i2 = i1 + 16;
double2 w = conjugate(negTwids5[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// 4_th iteration
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 8
// rem is the remainder of tid/8. tid takes values:
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ... N/4
// then rem takes values:
// 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, ... N/4
// and striding by 8 will allow us to cover all
// the coefficients correctly
int rem = tid & 7;
i1 = (tid << 1) - rem;
i2 = i1 + 8;
double2 w = conjugate(negTwids4[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// third iteration
// from k=3 on, we have to do the full complex multiplication
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 4
// rem is the remainder of tid/4. tid takes values:
// 0, 1, 2, 3, 4, 5, 6, 7, ... N/4
// then rem takes values:
// 0, 1, 2, 3, 0, 1, 2, 3, ... N/4
// and striding by 4 will allow us to cover all
// the coefficients correctly
int rem = tid & 3;
i1 = (tid << 1) - rem;
i2 = i1 + 4;
double2 w = conjugate(negTwids3[rem]);
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// second iteration: apply the butterfly pattern
// between groups of 4 coefficients
// k=2, \zeta=exp(i pi/4) for even coefficients and
// exp(3 i pi / 4) for odd coefficients
tid = threadIdx.x;
// odd = 0 for even coefficients, 1 for odd coefficients
int odd = tid & 1;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 2
// i1=2*tid if tid is even and 2*tid-1 if it is odd
i1 = (tid << 1) - odd;
i2 = i1 + 2;
double2 w;
if (odd) {
w.x = -0.707106781186547461715008466854;
w.y = -0.707106781186547572737310929369;
} else {
w.x = 0.707106781186547461715008466854;
w.y = -0.707106781186547572737310929369;
}
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
// Now we go through all the levels of the FFT one by one
// (instead of recursively)
// first iteration: k=1, zeta=i for all coefficients
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
// the butterfly pattern is applied to each pair
// of coefficients, with a stride of 2
i1 = tid << 1;
i2 = i1 + 1;
double2 w = {0, -1};
u = A[i1], v = A[i2];
A[i1] = (u + v) * 0.5;
A[i2] = (u - v) * w * 0.5;
tid += params::degree / params::opt;
}
__syncthreads();
bit_reverse_inplace<params>(A);
__syncthreads();
// Real polynomials handled should not exceed a degree of 8192
}
/*
 * Correction applied after the direct (forward) FFT.
 *
 * The forward transform runs as a half-size complex FFT over the packed real
 * coefficients; this pass unfolds that result into the spectrum of the real
 * negacyclic polynomial. It works purely in registers (no extra shared
 * memory): each thread loads opt/4 elements from the left half of x and their
 * mirrored partners x[N/2 - k - 1] from the right half, recombines each pair
 * with the negacyclic twiddle, and writes both elements back. Left indices
 * stay below degree/4 and mirrored indices at or above it, so each (left,
 * right) pair is touched by exactly one thread and the pass itself needs no
 * synchronization; call sites bracket it with synchronize_threads_in_block().
 * Based on Pascal's paper.
 */
template <class params>
__device__ void correction_direct_fft_inplace(double2 *x) {
constexpr int threads = params::degree / params::opt;
int tid = threadIdx.x;
// Register buffers for this thread's opt/4 mirrored element pairs.
double2 left[params::opt / 4];
double2 right[params::opt / 4];
#pragma unroll
for (int i = 0; i < params::opt / 4; i++) {
left[i] = x[tid + i * threads];
}
#pragma unroll
for (int i = 0; i < params::opt / 4; i++) {
right[i] = x[params::degree / 2 - (tid + i * threads + 1)];
}
#pragma unroll
for (int i = 0; i < params::opt / 4; i++) {
double2 tw = negacyclic_twiddle<params::degree>(tid + i * threads);
// Even/odd split of the mirrored pair: sums and differences of the
// real and imaginary parts.
double add_RE = left[i].x + right[i].x;
double sub_RE = left[i].x - right[i].x;
double add_IM = left[i].y + right[i].y;
double sub_IM = left[i].y - right[i].y;
// Twiddle rotation of the cross terms (note the sign pattern differs
// from correction_inverse_fft_inplace, which undoes this step).
double tmp1 = add_IM * tw.x + sub_RE * tw.y;
double tmp2 = -sub_RE * tw.x + add_IM * tw.y;
// Recombine and halve, writing both halves of the pair back in place.
x[tid + i * threads].x = (add_RE + tmp1) * 0.5;
x[tid + i * threads].y = (sub_IM + tmp2) * 0.5;
x[params::degree / 2 - (tid + i * threads + 1)].x = (add_RE - tmp1) * 0.5;
x[params::degree / 2 - (tid + i * threads + 1)].y = (-sub_IM + tmp2) * 0.5;
}
}
/*
 * Correction applied before the inverse FFT.
 *
 * Inverse counterpart of correction_direct_fft_inplace: it refolds the
 * spectrum of the real negacyclic polynomial back into the packed half-size
 * complex layout expected by the inverse half-size FFT. Same register-only
 * structure: each thread loads opt/4 elements and their mirrors
 * x[N/2 - k - 1], recombines each pair with the negacyclic twiddle (the
 * twiddle signs in tmp1/tmp2 are flipped relative to the direct version),
 * and writes both back. Pairs are disjoint across threads, so no internal
 * synchronization is needed; call sites bracket it with
 * synchronize_threads_in_block(). Based on Pascal's paper.
 */
template <class params>
__device__ void correction_inverse_fft_inplace(double2 *x) {
constexpr int threads = params::degree / params::opt;
int tid = threadIdx.x;
// Register buffers for this thread's opt/4 mirrored element pairs.
double2 left[params::opt / 4];
double2 right[params::opt / 4];
#pragma unroll
for (int i = 0; i < params::opt / 4; i++) {
left[i] = x[tid + i * threads];
}
#pragma unroll
for (int i = 0; i < params::opt / 4; i++) {
right[i] = x[params::degree / 2 - (tid + i * threads + 1)];
}
#pragma unroll
for (int i = 0; i < params::opt / 4; i++) {
double2 tw = negacyclic_twiddle<params::degree>(tid + i * threads);
double add_RE = left[i].x + right[i].x;
double sub_RE = left[i].x - right[i].x;
double add_IM = left[i].y + right[i].y;
double sub_IM = left[i].y - right[i].y;
// Twiddle rotation with the opposite sign pattern to the direct
// correction, undoing its unfolding step.
double tmp1 = add_IM * tw.x - sub_RE * tw.y;
double tmp2 = sub_RE * tw.x + add_IM * tw.y;
x[tid + i * threads].x = (add_RE - tmp1) * 0.5;
x[tid + i * threads].y = (sub_IM + tmp2) * 0.5;
x[params::degree / 2 - (tid + i * threads + 1)].x = (add_RE + tmp1) * 0.5;
x[params::degree / 2 - (tid + i * threads + 1)].y = (-sub_IM + tmp2) * 0.5;
}
}
/*
@@ -735,8 +165,6 @@ __global__ void batch_NSMFFT(double2 *d_input, double2 *d_output,
__syncthreads();
NSMFFT_direct<HalfDegree<params>>(fft);
__syncthreads();
correction_direct_fft_inplace<params>(fft);
__syncthreads();
tid = threadIdx.x;
#pragma unroll

File diff suppressed because it is too large Load Diff

View File

@@ -2,19 +2,6 @@
#ifndef GPU_BOOTSTRAP_TWIDDLES_CUH
#define GPU_BOOTSTRAP_TWIDDLES_CUH
extern __constant__ short SW1[4096];
extern __constant__ short SW2[4096];
extern __constant__ double2 negTwids3[4];
extern __constant__ double2 negTwids4[8];
extern __constant__ double2 negTwids5[16];
extern __constant__ double2 negTwids6[32];
extern __constant__ double2 negTwids7[64];
extern __constant__ double2 negTwids8[128];
extern __constant__ double2 negTwids9[256];
extern __constant__ double2 negTwids10[512];
extern __constant__ double2 negTwids11[1024];
extern __device__ double2 negTwids12[2048];
extern __device__ double2 negTwids13[4096];
extern __constant__ double2 negtwiddles[4096];
#endif

View File

@@ -170,8 +170,8 @@ __device__ void add_to_torus(double2 *m_values, Torus *result) {
Torus V2 = 0;
typecast_double_to_torus<Torus>(frac, V2);
result[tid * 2] += V1;
result[tid * 2 + 1] += V2;
result[tid] += V1;
result[tid + params::degree / 2] += V2;
tid = tid + params::degree / params::opt;
}
}

View File

@@ -21,17 +21,11 @@ template <class params> __device__ void fft(double2 *output) {
// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(output);
synchronize_threads_in_block();
correction_direct_fft_inplace<params>(output);
synchronize_threads_in_block();
}
template <class params> __device__ void ifft_inplace(double2 *data) {
synchronize_threads_in_block();
correction_inverse_fft_inplace<params>(data);
synchronize_threads_in_block();
NSMFFT_inverse<HalfDegree<params>>(data);
synchronize_threads_in_block();
}