feat(gpu): use warp level optimizations for fft

This commit is contained in:
Guillermo Oyarzun
2025-07-15 19:01:12 +02:00
parent d741e55218
commit 79d5db66d4
5 changed files with 178 additions and 21 deletions

View File

@@ -63,7 +63,7 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
}
Index twiddle_shift = 1;
for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
twiddle_shift <<= 1;
@@ -96,8 +96,43 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
tid = tid + STRIDE;
}
}
__syncthreads();
for (Index l = 4; l >= 1; --l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
twiddle_shift <<= 1;
tid = threadIdx.x;
__syncwarp();
double2 reg_A[BUTTERFLY_DEPTH];
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
tid = tid + STRIDE;
}
__syncwarp();
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
u[i] = (u_stays_in_register) ? u[i] : w;
v[i] = (u_stays_in_register) ? w : v[i];
w = negtwiddles[tid / lane_mask + twiddle_shift];
w *= v[i];
v[i] = u[i] - w;
u[i] = u[i] + w;
tid = tid + STRIDE;
}
}
__syncthreads();
// store registers in SM
tid = threadIdx.x;
#pragma unroll
@@ -155,7 +190,7 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
}
Index twiddle_shift = 1;
for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
twiddle_shift <<= 1;
@@ -188,8 +223,43 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
tid = tid + STRIDE;
}
}
__syncthreads();
for (Index l = 4; l >= 1; --l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
twiddle_shift <<= 1;
tid = threadIdx.x;
double2 reg_A[BUTTERFLY_DEPTH];
__syncwarp();
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
tid = tid + STRIDE;
}
__syncwarp();
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
u[i] = (u_stays_in_register) ? u[i] : w;
v[i] = (u_stays_in_register) ? w : v[i];
w = shared_twiddles[tid / lane_mask + twiddle_shift];
w *= v[i];
v[i] = u[i] - w;
u[i] = u[i] + w;
tid = tid + STRIDE;
}
}
__syncthreads();
// store registers in SM
tid = threadIdx.x;
#pragma unroll
@@ -236,7 +306,46 @@ template <class params> __device__ void NSMFFT_inverse(double2 *A) {
}
Index twiddle_shift = DEGREE;
for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
for (Index l = 1; l <= 4; ++l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
tid = threadIdx.x;
twiddle_shift >>= 1;
// at this point registers are ready for the butterfly
tid = threadIdx.x;
__syncwarp();
double2 reg_A[BUTTERFLY_DEPTH];
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
w = (u[i] - v[i]);
u[i] += v[i];
v[i] = w * conjugate(negtwiddles[tid / lane_mask + twiddle_shift]);
// keep one of the registers for the next iteration and stage the other one in reg_A for the warp shuffle
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
tid = tid + STRIDE;
}
__syncwarp();
// prepare registers for next butterfly iteration
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
u[i] = (u_stays_in_register) ? u[i] : w;
v[i] = (u_stays_in_register) ? w : v[i];
tid = tid + STRIDE;
}
}
for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
tid = threadIdx.x;
@@ -332,7 +441,46 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
}
Index twiddle_shift = DEGREE;
for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
for (Index l = 1; l <= 4; ++l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
tid = threadIdx.x;
twiddle_shift >>= 1;
// at this point registers are ready for the butterfly
tid = threadIdx.x;
__syncwarp();
double2 reg_A[BUTTERFLY_DEPTH];
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
w = (u[i] - v[i]);
u[i] += v[i];
v[i] = w * conjugate(shared_twiddles[tid / lane_mask + twiddle_shift]);
// keep one of the registers for the next iteration and stage the other one in reg_A for the warp shuffle
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
tid = tid + STRIDE;
}
__syncwarp();
// prepare registers for next butterfly iteration
tid = threadIdx.x;
#pragma unroll
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
Index rank = tid & thread_mask;
bool u_stays_in_register = rank < lane_mask;
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
u[i] = (u_stays_in_register) ? u[i] : w;
v[i] = (u_stays_in_register) ? w : v[i];
tid = tid + STRIDE;
}
}
for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
Index lane_mask = 1 << (l - 1);
Index thread_mask = (1 << l) - 1;
tid = threadIdx.x;

View File

@@ -415,15 +415,15 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64(
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
level_count, cuda_get_max_shared_memory(gpu_index));
if (supports_tbc &&
!(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
glwe_dimension, polynomial_size, level_count,
input_lwe_ciphertext_count, allocate_gpu_memory);
else
// if (supports_tbc &&
// !(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
glwe_dimension, polynomial_size, level_count, input_lwe_ciphertext_count,
allocate_gpu_memory);
// else
#endif
if (supports_cg)
if (supports_cg)
return scratch_cuda_cg_multi_bit_programmable_bootstrap<uint64_t>(
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
glwe_dimension, polynomial_size, level_count,

View File

@@ -177,6 +177,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
2 * params::degree] = -1;
}
double2 *shared_fft = (double2 *)(precalc_coefs + polynomial_size * 3);
// Ids
constexpr uint32_t level_id = 0;
uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1);
@@ -239,12 +240,10 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
polynomial_accumulate_monic_monomial_mul_on_regs_precalc<Torus, params>(
reg_acc, bsk_poly, precalc_coefs + jump, monomial_degree);
}
__syncthreads(); // needed because we are going to reuse the
// shared memory for the fft
// Move from local memory back to shared memory but as complex
int tid = threadIdx.x;
double2 *fft = (double2 *)selected_memory;
double2 *fft = shared_fft;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] =
@@ -708,10 +707,10 @@ __host__ void execute_compute_keybundle(
cudaFuncCachePreferShared);
check_cuda_error(cudaGetLastError());
device_multi_bit_programmable_bootstrap_keybundle_2_2_params<
Torus, Degree<2048>, FULLSM>
<<<grid_keybundle, thds_new_keybundle, 3 * full_sm_keybundle, stream>>>(
lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
lwe_offset, chunk_size, keybundle_size_per_input);
Torus, Degree<2048>, FULLSM><<<grid_keybundle, thds_new_keybundle,
3 * full_sm_keybundle, stream>>>(
lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
lwe_offset, chunk_size, keybundle_size_per_input);
} else {
device_multi_bit_programmable_bootstrap_keybundle<Torus, params, FULLSM>
<<<grid_keybundle, thds, full_sm_keybundle, stream>>>(

View File

@@ -354,6 +354,9 @@ device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
copy_polynomial_from_regs<Torus, params::opt, params::degree / params::opt>(
reg_acc_rotated, global_accumulator_slice);
}
// Before exiting the kernel we need to sync the cluster to ensure that
// other blocks can still access the dsm in the ping pong buffer
cluster.sync();
}
template <typename Torus>

View File

@@ -74,4 +74,11 @@ __device__ inline double2 operator*(double a, double2 b) {
return {__dmul_rn(b.x, a), __dmul_rn(b.y, a)};
}
// Warp-level XOR shuffle for a double2 value.
//
// Each component of `val` is exchanged independently with
// __shfl_xor_sync, so the calling lane receives the `x` and `y` held by the
// lane whose id is (own lane id XOR laneMask).
//
//   val      - value contributed by the calling lane
//   laneMask - XOR pattern applied to the lane id to select the partner lane
//   mask     - participation mask forwarded to __shfl_xor_sync; defaults to
//              0xFFFFFFFF (all 32 bits set)
//
// Returns the double2 read from the partner lane.
__device__ inline double2 shfl_xor_double2(double2 val, int laneMask,
                                           unsigned mask = 0xFFFFFFFF) {
  double2 from_partner;
  // The shuffle is issued once per component: x first, then y.
  from_partner.x = __shfl_xor_sync(mask, val.x, laneMask);
  from_partner.y = __shfl_xor_sync(mask, val.y, laneMask);
  return from_partner;
}
#endif