mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(cuda): add a new fft algorithm.
- FFT can work for any polynomial size, as long as twiddles are provided. - All the twiddles fit in the constant memory. - Bit reverse is not used anymore, no more sw1 and sw2 arrays in constant memory. - Real to complex compression algorithm is changed. - Twiddle initialization functions are removed.
This commit is contained in:
committed by
bbarbakadze
parent
bd9cbbc7af
commit
3cd48f0de2
@@ -5,9 +5,6 @@
|
||||
|
||||
extern "C" {
|
||||
|
||||
void cuda_initialize_twiddles(uint32_t polynomial_size, void *v_stream,
|
||||
uint32_t gpu_index);
|
||||
|
||||
void cuda_convert_lwe_bootstrap_key_32(void *dest, void *src, void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
uint32_t input_lwe_dim,
|
||||
|
||||
@@ -165,8 +165,6 @@ __global__ void device_bootstrap_amortized(
|
||||
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_direct_fft_inplace<params>(accumulator_fft);
|
||||
|
||||
// Get the bootstrapping key piece necessary for the multiplication
|
||||
// It is already in the Fourier domain
|
||||
auto bsk_mask_slice =
|
||||
@@ -194,8 +192,6 @@ __global__ void device_bootstrap_amortized(
|
||||
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_direct_fft_inplace<params>(accumulator_fft);
|
||||
|
||||
auto bsk_mask_slice_2 =
|
||||
get_ith_mask_kth_block(bootstrapping_key, iteration, 1, level,
|
||||
polynomial_size, 1, level_count);
|
||||
@@ -215,10 +211,6 @@ __global__ void device_bootstrap_amortized(
|
||||
if constexpr (SMD == FULLSM || SMD == NOSM) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_inverse_fft_inplace<params>(mask_res_fft);
|
||||
correction_inverse_fft_inplace<params>(body_res_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
NSMFFT_inverse<HalfDegree<params>>(mask_res_fft);
|
||||
NSMFFT_inverse<HalfDegree<params>>(body_res_fft);
|
||||
|
||||
@@ -227,6 +219,7 @@ __global__ void device_bootstrap_amortized(
|
||||
add_to_torus<Torus, params>(mask_res_fft, accumulator_mask);
|
||||
add_to_torus<Torus, params>(body_res_fft, accumulator_body);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
} else {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
@@ -236,9 +229,6 @@ __global__ void device_bootstrap_amortized(
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_inverse_fft_inplace<params>(accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
@@ -253,9 +243,6 @@ __global__ void device_bootstrap_amortized(
|
||||
}
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_inverse_fft_inplace<params>(accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
|
||||
@@ -38,9 +38,6 @@ mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
|
||||
NSMFFT_direct<HalfDegree<params>>(fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_direct_fft_inplace<params>(fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Get the pieces of the bootstrapping key that will be needed for the
|
||||
// external product; blockIdx.x is the ID of the block that's executing
|
||||
// this function, so we end up getting the lines of the bootstrapping key
|
||||
@@ -113,9 +110,6 @@ mul_ggsw_glwe(Torus *accumulator, double2 *fft, double2 *mask_join_buffer,
|
||||
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_inverse_fft_inplace<params>(fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Perform the inverse FFT on the result of the GGSW x GLWE and add to the
|
||||
// accumulator
|
||||
NSMFFT_inverse<HalfDegree<params>>(fft);
|
||||
|
||||
@@ -38,39 +38,6 @@ __device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
|
||||
polynomial_size / 2];
|
||||
}
|
||||
|
||||
void cuda_initialize_twiddles(uint32_t polynomial_size, void *v_stream,
|
||||
uint32_t gpu_index) {
|
||||
cudaSetDevice(gpu_index);
|
||||
int sw_size = polynomial_size / 2;
|
||||
short *sw1_h, *sw2_h;
|
||||
|
||||
sw1_h = (short *)malloc(sizeof(short) * sw_size);
|
||||
sw2_h = (short *)malloc(sizeof(short) * sw_size);
|
||||
|
||||
memset(sw1_h, 0, sw_size * sizeof(short));
|
||||
memset(sw2_h, 0, sw_size * sizeof(short));
|
||||
int cnt = 0;
|
||||
for (int i = 1, j = 0; i < polynomial_size / 2; i++) {
|
||||
int bit = (polynomial_size / 2) >> 1;
|
||||
for (; j & bit; bit >>= 1)
|
||||
j ^= bit;
|
||||
j ^= bit;
|
||||
|
||||
if (i < j) {
|
||||
sw1_h[cnt] = i;
|
||||
sw2_h[cnt] = j;
|
||||
cnt++;
|
||||
}
|
||||
}
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
cudaMemcpyToSymbolAsync(SW1, sw1_h, sw_size * sizeof(short), 0,
|
||||
cudaMemcpyHostToDevice, *stream);
|
||||
cudaMemcpyToSymbolAsync(SW2, sw2_h, sw_size * sizeof(short), 0,
|
||||
cudaMemcpyHostToDevice, *stream);
|
||||
free(sw1_h);
|
||||
free(sw2_h);
|
||||
}
|
||||
|
||||
template <typename T, typename ST>
|
||||
void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
uint32_t gpu_index, uint32_t input_lwe_dim,
|
||||
@@ -101,10 +68,9 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
int complex_current_poly_idx = i * polynomial_size / 2;
|
||||
int torus_current_poly_idx = i * polynomial_size;
|
||||
for (int j = 0; j < polynomial_size / 2; j++) {
|
||||
h_bsk[complex_current_poly_idx + j].x =
|
||||
src[torus_current_poly_idx + 2 * j];
|
||||
h_bsk[complex_current_poly_idx + j].x = src[torus_current_poly_idx + j];
|
||||
h_bsk[complex_current_poly_idx + j].y =
|
||||
src[torus_current_poly_idx + 2 * j + 1];
|
||||
src[torus_current_poly_idx + j + polynomial_size / 2];
|
||||
h_bsk[complex_current_poly_idx + j].x /=
|
||||
(double)std::numeric_limits<T>::max();
|
||||
h_bsk[complex_current_poly_idx + j].y /=
|
||||
|
||||
@@ -34,16 +34,16 @@ public:
|
||||
int tid = threadIdx.x;
|
||||
current_level -= 1;
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
T res_re = state[tid * 2] & mask_mod_b;
|
||||
T res_im = state[tid * 2 + 1] & mask_mod_b;
|
||||
state[tid * 2] >>= base_log;
|
||||
state[tid * 2 + 1] >>= base_log;
|
||||
T carry_re = ((res_re - 1ll) | state[tid * 2]) & res_re;
|
||||
T carry_im = ((res_im - 1ll) | state[tid * 2 + 1]) & res_im;
|
||||
T res_re = state[tid] & mask_mod_b;
|
||||
T res_im = state[tid + params::degree / 2] & mask_mod_b;
|
||||
state[tid] >>= base_log;
|
||||
state[tid + params::degree / 2] >>= base_log;
|
||||
T carry_re = ((res_re - 1ll) | state[tid]) & res_re;
|
||||
T carry_im = ((res_im - 1ll) | state[tid + params::degree / 2]) & res_im;
|
||||
carry_re >>= (base_log - 1);
|
||||
carry_im >>= (base_log - 1);
|
||||
state[tid * 2] += carry_re;
|
||||
state[tid * 2 + 1] += carry_im;
|
||||
state[tid] += carry_re;
|
||||
state[tid + params::degree / 2] += carry_im;
|
||||
res_re -= carry_re << base_log;
|
||||
res_im -= carry_im << base_log;
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@ __global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < log_2_opt; i++) {
|
||||
ST x = src[(2 * tid) + params::opt * offset];
|
||||
ST y = src[(2 * tid + 1) + params::opt * offset];
|
||||
ST x = src[(tid) + params::opt * offset];
|
||||
ST y = src[(tid + params::degree / 2) + params::opt * offset];
|
||||
selected_memory[tid].x = x / (double)std::numeric_limits<T>::max();
|
||||
selected_memory[tid].y = y / (double)std::numeric_limits<T>::max();
|
||||
tid += params::degree / params::opt;
|
||||
@@ -35,9 +35,6 @@ __global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
|
||||
NSMFFT_direct<HalfDegree<params>>(selected_memory);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_direct_fft_inplace<params>(selected_memory);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Write the output to global memory
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
|
||||
@@ -6,55 +6,6 @@
|
||||
#include "polynomial/parameters.cuh"
|
||||
#include "twiddles.cuh"
|
||||
|
||||
/*
|
||||
* bit reverse
|
||||
* coefficient bits are reversed based on precalculated indexes
|
||||
* SW1 and SW2
|
||||
*/
|
||||
template <class params> __device__ void bit_reverse_inplace(double2 *A) {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
short sw1 = SW1[tid];
|
||||
short sw2 = SW2[tid];
|
||||
double2 tmp = A[sw1];
|
||||
A[sw1] = A[sw2];
|
||||
A[sw2] = tmp;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* negacyclic twiddle
|
||||
* returns negacyclic twiddle based on degree and index
|
||||
* twiddles are precalculated inside negTwids{3..13} arrays
|
||||
*/
|
||||
template <int degree> __device__ double2 negacyclic_twiddle(int j) {
|
||||
double2 twid;
|
||||
switch (degree) {
|
||||
case 512:
|
||||
twid = negTwids9[j];
|
||||
break;
|
||||
case 1024:
|
||||
twid = negTwids10[j];
|
||||
break;
|
||||
case 2048:
|
||||
twid = negTwids11[j];
|
||||
break;
|
||||
case 4096:
|
||||
twid = negTwids12[j];
|
||||
break;
|
||||
case 8192:
|
||||
twid = negTwids13[j];
|
||||
break;
|
||||
default:
|
||||
twid.x = 0;
|
||||
twid.y = 0;
|
||||
break;
|
||||
}
|
||||
return twid;
|
||||
}
|
||||
|
||||
/*
|
||||
* Direct negacyclic FFT:
|
||||
* - before the FFT the N real coefficients are stored into a
|
||||
@@ -75,641 +26,120 @@ template <int degree> __device__ double2 negacyclic_twiddle(int j) {
|
||||
* forward_negacyclic_fft_inplace function of bootstrap.cuh
|
||||
*/
|
||||
template <class params> __device__ void NSMFFT_direct(double2 *A) {
|
||||
/* First, reverse the bits of the input complex
|
||||
* The bit reversal for half-size FFT has been stored into the
|
||||
* SW1 and SW2 arrays beforehand
|
||||
|
||||
/* We don't make bit reverse here, since twiddles are already reversed
|
||||
* Each thread is always in charge of "opt/2" pairs of coefficients,
|
||||
* which is why we always loop through N/2 by N/opt strides
|
||||
* The pragma unroll instruction tells the compiler to unroll the
|
||||
* full loop, which should increase performance
|
||||
*/
|
||||
bit_reverse_inplace<params>(A);
|
||||
__syncthreads();
|
||||
|
||||
// Now we go through all the levels of the FFT one by one
|
||||
// (instead of recursively)
|
||||
// first iteration: k=1, zeta=i for all coefficients
|
||||
int tid = threadIdx.x;
|
||||
int i1, i2;
|
||||
double2 u, v;
|
||||
size_t tid = threadIdx.x;
|
||||
size_t twid_id;
|
||||
size_t t = params::degree / 2;
|
||||
size_t m = 1;
|
||||
size_t i1, i2;
|
||||
double2 u, v, w;
|
||||
// level 1
|
||||
// we don't make actual complex multiplication on level1 since we have only
|
||||
// one twiddle, it's real and image parts are equal, so we can multiply
|
||||
// it with simpler operations
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 2
|
||||
i1 = tid << 1;
|
||||
i2 = i1 + 1;
|
||||
for (size_t i = 0; i < params::opt / 2; ++i) {
|
||||
i1 = tid;
|
||||
i2 = tid + t;
|
||||
u = A[i1];
|
||||
// v = i*A[i2]
|
||||
v.y = A[i2].x;
|
||||
v.x = -A[i2].y;
|
||||
// A[i1] <- A[i1] + i*A[i2]
|
||||
// A[i2] <- A[i1] - i*A[i2]
|
||||
A[i1] += v;
|
||||
A[i2] = u - v;
|
||||
v.x = (A[i2].x - A[i2].y) * 0.707106781186547461715008466854;
|
||||
v.y = (A[i2].x + A[i2].y) * 0.707106781186547461715008466854;
|
||||
A[i1].x += v.x;
|
||||
A[i1].y += v.y;
|
||||
|
||||
A[i2].x = u.x - v.x;
|
||||
A[i2].y = u.y - v.y;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// second iteration: apply the butterfly pattern
|
||||
// between groups of 4 coefficients
|
||||
// k=2, \zeta=exp(i pi/4) for even coefficients and
|
||||
// exp(3 i pi / 4) for odd coefficients
|
||||
tid = threadIdx.x;
|
||||
// odd = 0 for even coefficients, 1 for odd coefficients
|
||||
int odd = tid & 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 2
|
||||
// i1=2*tid if tid is even and 2*tid-1 if it is odd
|
||||
i1 = (tid << 1) - odd;
|
||||
i2 = i1 + 2;
|
||||
|
||||
double a = A[i2].x;
|
||||
double b = A[i2].y;
|
||||
u = A[i1];
|
||||
|
||||
// \zeta_j,2 = exp(-i pi (2j-1)/4) -> j=0: exp(i pi/4) or j=1: exp(-i pi/4)
|
||||
// \zeta_even = sqrt(2)/2 + i * sqrt(2)/2 = sqrt(2)/2*(1+i)
|
||||
// \zeta_odd = sqrt(2)/2 - i * sqrt(2)/2 = sqrt(2)/2*(1-i)
|
||||
|
||||
// v_j = \zeta_j * (a+i*b)
|
||||
// v_even = sqrt(2)/2*((a-b)+i*(a+b))
|
||||
// v_odd = sqrt(2)/2*(a+b+i*(b-a))
|
||||
v.x =
|
||||
(odd) ? (-0.707106781186548) * (a + b) : (0.707106781186548) * (a - b);
|
||||
v.y = (odd) ? (0.707106781186548) * (a - b) : (0.707106781186548) * (a + b);
|
||||
|
||||
// v.x = (0.707106781186548 * odd) * (a + b) + (0.707106781186548 * (!odd))
|
||||
// * (a - b); v.y = (0.707106781186548 * odd) * (b - a) + (0.707106781186548
|
||||
// * (!odd)) * (a + b);
|
||||
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// third iteration
|
||||
// from k=3 on, we have to do the full complex multiplication
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 4
|
||||
// rem is the remainder of tid/4. tid takes values:
|
||||
// 0, 1, 2, 3, 4, 5, 6, 7, ... N/4
|
||||
// then rem takes values:
|
||||
// 0, 1, 2, 3, 0, 1, 2, 3, ... N/4
|
||||
// and striding by 4 will allow us to cover all
|
||||
// the coefficients correctly
|
||||
int rem = tid & 3;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 4;
|
||||
|
||||
double2 w = negTwids3[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 4_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 8
|
||||
// rem is the remainder of tid/8. tid takes values:
|
||||
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ... N/4
|
||||
// then rem takes values:
|
||||
// 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, ... N/4
|
||||
// and striding by 8 will allow us to cover all
|
||||
// the coefficients correctly
|
||||
int rem = tid & 7;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 8;
|
||||
|
||||
double2 w = negTwids4[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 5_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 16
|
||||
// rem is the remainder of tid/16
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 15;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 16;
|
||||
double2 w = negTwids5[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 6_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 32
|
||||
// rem is the remainder of tid/32
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 31;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 32;
|
||||
double2 w = negTwids6[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 7_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 64
|
||||
// rem is the remainder of tid/64
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 63;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 64;
|
||||
double2 w = negTwids7[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 8_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 128
|
||||
// rem is the remainder of tid/128
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 127;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 128;
|
||||
double2 w = negTwids8[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
if constexpr (params::log2_degree > 8) {
|
||||
// 9_th iteration
|
||||
size_t iter = 1;
|
||||
// for levels more than 1
|
||||
// from here none of the twiddles have equal real and imag part, so
|
||||
// complete complex multiplication has to be done
|
||||
// here we have more than one twiddles
|
||||
while (t > 1) {
|
||||
iter++;
|
||||
tid = threadIdx.x;
|
||||
t >>= 1;
|
||||
m <<= 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 256
|
||||
// rem is the remainder of tid/256
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 255;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 256;
|
||||
double2 w = negTwids9[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
for (size_t i = 0; i < params::opt / 2; ++i) {
|
||||
twid_id = tid / t;
|
||||
i1 = 2 * t * twid_id + (tid & (t - 1));
|
||||
i2 = i1 + t;
|
||||
w = negtwiddles[twid_id + m];
|
||||
u = A[i1];
|
||||
v.x = A[i2].x * w.x - A[i2].y * w.y;
|
||||
v.y = A[i2].y * w.x + A[i2].x * w.y;
|
||||
A[i1].x += v.x;
|
||||
A[i1].y += v.y;
|
||||
A[i2].x = u.x - v.x;
|
||||
A[i2].y = u.y - v.y;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if constexpr (params::log2_degree > 9) {
|
||||
// 10_th iteration
|
||||
tid = threadIdx.x;
|
||||
//#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 512
|
||||
// rem is the remainder of tid/512
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 511;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 512;
|
||||
double2 w = negTwids10[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if constexpr (params::log2_degree > 10) {
|
||||
// 11_th iteration
|
||||
tid = threadIdx.x;
|
||||
//#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 1024
|
||||
// rem is the remainder of tid/1024
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 1023;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 1024;
|
||||
double2 w = negTwids11[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if constexpr (params::log2_degree > 11) {
|
||||
// 12_th iteration
|
||||
tid = threadIdx.x;
|
||||
//#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 2048
|
||||
// rem is the remainder of tid/2048
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 2047;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 2048;
|
||||
double2 w = negTwids12[rem];
|
||||
u = A[i1], v = A[i2] * w;
|
||||
A[i1] = u + v;
|
||||
A[i2] = u - v;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// Real polynomials handled should not exceed a degree of 8192
|
||||
}
|
||||
|
||||
/*
|
||||
* negacyclic inverse fft
|
||||
*/
|
||||
template <class params> __device__ void NSMFFT_inverse(double2 *A) {
|
||||
/* First, reverse the bits of the input complex
|
||||
* The bit reversal for half-size FFT has been stored into the
|
||||
* SW1 and SW2 arrays beforehand
|
||||
|
||||
/* We don't make bit reverse here, since twiddles are already reversed
|
||||
* Each thread is always in charge of "opt/2" pairs of coefficients,
|
||||
* which is why we always loop through N/2 by N/opt strides
|
||||
* The pragma unroll instruction tells the compiler to unroll the
|
||||
* full loop, which should increase performance
|
||||
*/
|
||||
int tid;
|
||||
int i1, i2;
|
||||
double2 u, v;
|
||||
if constexpr (params::log2_degree > 11) {
|
||||
// 12_th iteration
|
||||
|
||||
size_t tid = threadIdx.x;
|
||||
size_t twid_id;
|
||||
size_t m = params::degree;
|
||||
size_t t = 1;
|
||||
size_t i1, i2;
|
||||
double2 u, w;
|
||||
|
||||
tid = threadIdx.x;
|
||||
for (size_t i = 0; i < params::opt; ++i) {
|
||||
A[tid].x *= 1. / params::degree;
|
||||
A[tid].y *= 1. / params::degree;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// none of the twiddles have equal real and imag part, so
|
||||
// complete complex multiplication has to be done
|
||||
// here we have more than one twiddles
|
||||
while (m > 1) {
|
||||
tid = threadIdx.x;
|
||||
//#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 2048
|
||||
// rem is the remainder of tid/2048
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 2047;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 2048;
|
||||
double2 w = conjugate(negTwids12[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
m >>= 1;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < params::opt / 2; ++i) {
|
||||
twid_id = tid / t;
|
||||
i1 = 2 * t * twid_id + (tid & (t - 1));
|
||||
i2 = i1 + t;
|
||||
w = negtwiddles[twid_id + m];
|
||||
u.x = A[i1].x - A[i2].x;
|
||||
u.y = A[i1].y - A[i2].y;
|
||||
A[i1].x += A[i2].x;
|
||||
A[i1].y += A[i2].y;
|
||||
|
||||
A[i2].x = u.x * w.x + u.y * w.y;
|
||||
A[i2].y = u.y * w.x - u.x * w.y;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
t <<= 1;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if constexpr (params::log2_degree > 10) {
|
||||
// 11_th iteration
|
||||
tid = threadIdx.x;
|
||||
//#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 1024
|
||||
// rem is the remainder of tid/1024
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 1023;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 1024;
|
||||
double2 w = conjugate(negTwids11[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if constexpr (params::log2_degree > 9) {
|
||||
// 10_th iteration
|
||||
tid = threadIdx.x;
|
||||
//#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 512
|
||||
// rem is the remainder of tid/512
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 511;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 512;
|
||||
double2 w = conjugate(negTwids10[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if constexpr (params::log2_degree > 8) {
|
||||
// 9_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 256
|
||||
// rem is the remainder of tid/256
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 255;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 256;
|
||||
double2 w = conjugate(negTwids9[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// 8_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 128
|
||||
// rem is the remainder of tid/128
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 127;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 128;
|
||||
double2 w = conjugate(negTwids8[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 7_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 64
|
||||
// rem is the remainder of tid/64
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 63;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 64;
|
||||
double2 w = conjugate(negTwids7[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 6_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 32
|
||||
// rem is the remainder of tid/32
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 31;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 32;
|
||||
double2 w = conjugate(negTwids6[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 5_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 16
|
||||
// rem is the remainder of tid/16
|
||||
// and the same logic as for previous iterations applies
|
||||
int rem = tid & 15;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 16;
|
||||
double2 w = conjugate(negTwids5[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 4_th iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 8
|
||||
// rem is the remainder of tid/8. tid takes values:
|
||||
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ... N/4
|
||||
// then rem takes values:
|
||||
// 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, ... N/4
|
||||
// and striding by 8 will allow us to cover all
|
||||
// the coefficients correctly
|
||||
int rem = tid & 7;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 8;
|
||||
|
||||
double2 w = conjugate(negTwids4[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// third iteration
|
||||
// from k=3 on, we have to do the full complex multiplication
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 4
|
||||
// rem is the remainder of tid/4. tid takes values:
|
||||
// 0, 1, 2, 3, 4, 5, 6, 7, ... N/4
|
||||
// then rem takes values:
|
||||
// 0, 1, 2, 3, 0, 1, 2, 3, ... N/4
|
||||
// and striding by 4 will allow us to cover all
|
||||
// the coefficients correctly
|
||||
int rem = tid & 3;
|
||||
i1 = (tid << 1) - rem;
|
||||
i2 = i1 + 4;
|
||||
|
||||
double2 w = conjugate(negTwids3[rem]);
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// second iteration: apply the butterfly pattern
|
||||
// between groups of 4 coefficients
|
||||
// k=2, \zeta=exp(i pi/4) for even coefficients and
|
||||
// exp(3 i pi / 4) for odd coefficients
|
||||
tid = threadIdx.x;
|
||||
// odd = 0 for even coefficients, 1 for odd coefficients
|
||||
int odd = tid & 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 2
|
||||
// i1=2*tid if tid is even and 2*tid-1 if it is odd
|
||||
i1 = (tid << 1) - odd;
|
||||
i2 = i1 + 2;
|
||||
|
||||
double2 w;
|
||||
if (odd) {
|
||||
w.x = -0.707106781186547461715008466854;
|
||||
w.y = -0.707106781186547572737310929369;
|
||||
} else {
|
||||
w.x = 0.707106781186547461715008466854;
|
||||
w.y = -0.707106781186547572737310929369;
|
||||
}
|
||||
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Now we go through all the levels of the FFT one by one
|
||||
// (instead of recursively)
|
||||
// first iteration: k=1, zeta=i for all coefficients
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
// the butterfly pattern is applied to each pair
|
||||
// of coefficients, with a stride of 2
|
||||
i1 = tid << 1;
|
||||
i2 = i1 + 1;
|
||||
double2 w = {0, -1};
|
||||
u = A[i1], v = A[i2];
|
||||
A[i1] = (u + v) * 0.5;
|
||||
A[i2] = (u - v) * w * 0.5;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
bit_reverse_inplace<params>(A);
|
||||
__syncthreads();
|
||||
// Real polynomials handled should not exceed a degree of 8192
|
||||
}
|
||||
|
||||
/*
|
||||
* correction after direct fft
|
||||
* does not use extra shared memory for recovering
|
||||
* correction is done using registers.
|
||||
* based on Pascal's paper
|
||||
*/
|
||||
template <class params>
|
||||
__device__ void correction_direct_fft_inplace(double2 *x) {
|
||||
constexpr int threads = params::degree / params::opt;
|
||||
int tid = threadIdx.x;
|
||||
double2 left[params::opt / 4];
|
||||
double2 right[params::opt / 4];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
left[i] = x[tid + i * threads];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
right[i] = x[params::degree / 2 - (tid + i * threads + 1)];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
double2 tw = negacyclic_twiddle<params::degree>(tid + i * threads);
|
||||
double add_RE = left[i].x + right[i].x;
|
||||
double sub_RE = left[i].x - right[i].x;
|
||||
double add_IM = left[i].y + right[i].y;
|
||||
double sub_IM = left[i].y - right[i].y;
|
||||
|
||||
double tmp1 = add_IM * tw.x + sub_RE * tw.y;
|
||||
double tmp2 = -sub_RE * tw.x + add_IM * tw.y;
|
||||
x[tid + i * threads].x = (add_RE + tmp1) * 0.5;
|
||||
x[tid + i * threads].y = (sub_IM + tmp2) * 0.5;
|
||||
x[params::degree / 2 - (tid + i * threads + 1)].x = (add_RE - tmp1) * 0.5;
|
||||
x[params::degree / 2 - (tid + i * threads + 1)].y = (-sub_IM + tmp2) * 0.5;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* correction before inverse fft
|
||||
* does not use extra shared memory for recovering
|
||||
* correction is done using registers.
|
||||
* based on Pascal's paper
|
||||
*/
|
||||
template <class params>
|
||||
__device__ void correction_inverse_fft_inplace(double2 *x) {
|
||||
constexpr int threads = params::degree / params::opt;
|
||||
int tid = threadIdx.x;
|
||||
double2 left[params::opt / 4];
|
||||
double2 right[params::opt / 4];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
left[i] = x[tid + i * threads];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
right[i] = x[params::degree / 2 - (tid + i * threads + 1)];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 4; i++) {
|
||||
double2 tw = negacyclic_twiddle<params::degree>(tid + i * threads);
|
||||
double add_RE = left[i].x + right[i].x;
|
||||
double sub_RE = left[i].x - right[i].x;
|
||||
double add_IM = left[i].y + right[i].y;
|
||||
double sub_IM = left[i].y - right[i].y;
|
||||
|
||||
double tmp1 = add_IM * tw.x - sub_RE * tw.y;
|
||||
double tmp2 = sub_RE * tw.x + add_IM * tw.y;
|
||||
x[tid + i * threads].x = (add_RE - tmp1) * 0.5;
|
||||
x[tid + i * threads].y = (sub_IM + tmp2) * 0.5;
|
||||
x[params::degree / 2 - (tid + i * threads + 1)].x = (add_RE + tmp1) * 0.5;
|
||||
x[params::degree / 2 - (tid + i * threads + 1)].y = (-sub_IM + tmp2) * 0.5;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -735,8 +165,6 @@ __global__ void batch_NSMFFT(double2 *d_input, double2 *d_output,
|
||||
__syncthreads();
|
||||
NSMFFT_direct<HalfDegree<params>>(fft);
|
||||
__syncthreads();
|
||||
correction_direct_fft_inplace<params>(fft);
|
||||
__syncthreads();
|
||||
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
|
||||
12213
src/fft/twiddles.cu
12213
src/fft/twiddles.cu
File diff suppressed because it is too large
Load Diff
@@ -2,19 +2,6 @@
|
||||
#ifndef GPU_BOOTSTRAP_TWIDDLES_CUH
|
||||
#define GPU_BOOTSTRAP_TWIDDLES_CUH
|
||||
|
||||
extern __constant__ short SW1[4096];
|
||||
extern __constant__ short SW2[4096];
|
||||
|
||||
extern __constant__ double2 negTwids3[4];
|
||||
extern __constant__ double2 negTwids4[8];
|
||||
extern __constant__ double2 negTwids5[16];
|
||||
extern __constant__ double2 negTwids6[32];
|
||||
extern __constant__ double2 negTwids7[64];
|
||||
extern __constant__ double2 negTwids8[128];
|
||||
extern __constant__ double2 negTwids9[256];
|
||||
extern __constant__ double2 negTwids10[512];
|
||||
extern __constant__ double2 negTwids11[1024];
|
||||
extern __device__ double2 negTwids12[2048];
|
||||
extern __device__ double2 negTwids13[4096];
|
||||
extern __constant__ double2 negtwiddles[4096];
|
||||
|
||||
#endif
|
||||
|
||||
@@ -170,8 +170,8 @@ __device__ void add_to_torus(double2 *m_values, Torus *result) {
|
||||
Torus V2 = 0;
|
||||
typecast_double_to_torus<Torus>(frac, V2);
|
||||
|
||||
result[tid * 2] += V1;
|
||||
result[tid * 2 + 1] += V2;
|
||||
result[tid] += V1;
|
||||
result[tid + params::degree / 2] += V2;
|
||||
tid = tid + params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,17 +21,11 @@ template <class params> __device__ void fft(double2 *output) {
|
||||
// Switch to the FFT space
|
||||
NSMFFT_direct<HalfDegree<params>>(output);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_direct_fft_inplace<params>(output);
|
||||
synchronize_threads_in_block();
|
||||
}
|
||||
|
||||
template <class params> __device__ void ifft_inplace(double2 *data) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
correction_inverse_fft_inplace<params>(data);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
NSMFFT_inverse<HalfDegree<params>>(data);
|
||||
synchronize_threads_in_block();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user