mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(gpu): use warp level optimizations for fft
This commit is contained in:
@@ -63,7 +63,7 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
|
||||
}
|
||||
|
||||
Index twiddle_shift = 1;
|
||||
for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
|
||||
for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
twiddle_shift <<= 1;
|
||||
@@ -96,8 +96,43 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (Index l = 4; l >= 1; --l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
twiddle_shift <<= 1;
|
||||
|
||||
tid = threadIdx.x;
|
||||
__syncwarp();
|
||||
double2 reg_A[BUTTERFLY_DEPTH];
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
w = negtwiddles[tid / lane_mask + twiddle_shift];
|
||||
|
||||
w *= v[i];
|
||||
|
||||
v[i] = u[i] - w;
|
||||
u[i] = u[i] + w;
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// store registers in SM
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
@@ -155,7 +190,7 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
}
|
||||
|
||||
Index twiddle_shift = 1;
|
||||
for (Index l = LOG2_DEGREE - 1; l >= 1; --l) {
|
||||
for (Index l = LOG2_DEGREE - 1; l >= 5; --l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
twiddle_shift <<= 1;
|
||||
@@ -188,8 +223,43 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (Index l = 4; l >= 1; --l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
twiddle_shift <<= 1;
|
||||
|
||||
tid = threadIdx.x;
|
||||
double2 reg_A[BUTTERFLY_DEPTH];
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
w = shared_twiddles[tid / lane_mask + twiddle_shift];
|
||||
|
||||
w *= v[i];
|
||||
|
||||
v[i] = u[i] - w;
|
||||
u[i] = u[i] + w;
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// store registers in SM
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
@@ -236,7 +306,46 @@ template <class params> __device__ void NSMFFT_inverse(double2 *A) {
|
||||
}
|
||||
|
||||
Index twiddle_shift = DEGREE;
|
||||
for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
|
||||
for (Index l = 1; l <= 4; ++l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
tid = threadIdx.x;
|
||||
twiddle_shift >>= 1;
|
||||
|
||||
// at this point registers are ready for the butterfly
|
||||
tid = threadIdx.x;
|
||||
__syncwarp();
|
||||
double2 reg_A[BUTTERFLY_DEPTH];
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
w = (u[i] - v[i]);
|
||||
u[i] += v[i];
|
||||
v[i] = w * conjugate(negtwiddles[tid / lane_mask + twiddle_shift]);
|
||||
|
||||
// keep one of the register for next iteration and store another one in sm
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// prepare registers for next butterfly iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
tid = threadIdx.x;
|
||||
@@ -332,7 +441,46 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
}
|
||||
|
||||
Index twiddle_shift = DEGREE;
|
||||
for (Index l = 1; l <= LOG2_DEGREE - 1; ++l) {
|
||||
for (Index l = 1; l <= 4; ++l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
tid = threadIdx.x;
|
||||
twiddle_shift >>= 1;
|
||||
|
||||
// at this point registers are ready for the butterfly
|
||||
tid = threadIdx.x;
|
||||
__syncwarp();
|
||||
double2 reg_A[BUTTERFLY_DEPTH];
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
w = (u[i] - v[i]);
|
||||
u[i] += v[i];
|
||||
v[i] = w * conjugate(shared_twiddles[tid / lane_mask + twiddle_shift]);
|
||||
|
||||
// keep one of the register for next iteration and store another one in sm
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// prepare registers for next butterfly iteration
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
for (Index l = 5; l <= LOG2_DEGREE - 1; ++l) {
|
||||
Index lane_mask = 1 << (l - 1);
|
||||
Index thread_mask = (1 << l) - 1;
|
||||
tid = threadIdx.x;
|
||||
|
||||
@@ -415,15 +415,15 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64(
|
||||
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
|
||||
level_count, cuda_get_max_shared_memory(gpu_index));
|
||||
|
||||
if (supports_tbc &&
|
||||
!(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
|
||||
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
input_lwe_ciphertext_count, allocate_gpu_memory);
|
||||
else
|
||||
// if (supports_tbc &&
|
||||
// !(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
|
||||
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count, input_lwe_ciphertext_count,
|
||||
allocate_gpu_memory);
|
||||
// else
|
||||
#endif
|
||||
if (supports_cg)
|
||||
if (supports_cg)
|
||||
return scratch_cuda_cg_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
|
||||
@@ -177,6 +177,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
|
||||
2 * params::degree] = -1;
|
||||
}
|
||||
|
||||
double2 *shared_fft = (double2 *)(precalc_coefs + polynomial_size * 3);
|
||||
// Ids
|
||||
constexpr uint32_t level_id = 0;
|
||||
uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1);
|
||||
@@ -239,12 +240,10 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
|
||||
polynomial_accumulate_monic_monomial_mul_on_regs_precalc<Torus, params>(
|
||||
reg_acc, bsk_poly, precalc_coefs + jump, monomial_degree);
|
||||
}
|
||||
__syncthreads(); // needed because we are going to reuse the
|
||||
// shared memory for the fft
|
||||
|
||||
// Move from local memory back to shared memory but as complex
|
||||
int tid = threadIdx.x;
|
||||
double2 *fft = (double2 *)selected_memory;
|
||||
double2 *fft = shared_fft;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
fft[tid] =
|
||||
@@ -708,10 +707,10 @@ __host__ void execute_compute_keybundle(
|
||||
cudaFuncCachePreferShared);
|
||||
check_cuda_error(cudaGetLastError());
|
||||
device_multi_bit_programmable_bootstrap_keybundle_2_2_params<
|
||||
Torus, Degree<2048>, FULLSM>
|
||||
<<<grid_keybundle, thds_new_keybundle, 3 * full_sm_keybundle, stream>>>(
|
||||
lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
|
||||
lwe_offset, chunk_size, keybundle_size_per_input);
|
||||
Torus, Degree<2048>, FULLSM><<<grid_keybundle, thds_new_keybundle,
|
||||
3 * full_sm_keybundle, stream>>>(
|
||||
lwe_array_in, lwe_input_indexes, keybundle_fft, bootstrapping_key,
|
||||
lwe_offset, chunk_size, keybundle_size_per_input);
|
||||
} else {
|
||||
device_multi_bit_programmable_bootstrap_keybundle<Torus, params, FULLSM>
|
||||
<<<grid_keybundle, thds, full_sm_keybundle, stream>>>(
|
||||
|
||||
@@ -354,6 +354,9 @@ device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
|
||||
copy_polynomial_from_regs<Torus, params::opt, params::degree / params::opt>(
|
||||
reg_acc_rotated, global_accumulator_slice);
|
||||
}
|
||||
// Before exiting the kernel we need to sync the cluster to ensure that
|
||||
// that other blocks can still access the dsm in the ping pong buffer
|
||||
cluster.sync();
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -74,4 +74,11 @@ __device__ inline double2 operator*(double a, double2 b) {
|
||||
return {__dmul_rn(b.x, a), __dmul_rn(b.y, a)};
|
||||
}
|
||||
|
||||
// Warp-level XOR shuffle for a double2.
//
// Each component of `val` is passed through __shfl_xor_sync separately, using
// the same `laneMask` and participant `mask` for both, and the two results are
// returned as a new double2 (x from the first shuffle, y from the second).
//
//   val      - value contributed by the calling lane
//   laneMask - forwarded unchanged as the XOR lane mask of __shfl_xor_sync
//   mask     - participant mask forwarded to __shfl_xor_sync; defaults to
//              0xFFFFFFFF
//
// NOTE(review): with the default mask the caller presumably has to guarantee
// that all 32 lanes of the warp reach this call together - confirm against the
// CUDA documentation for __shfl_xor_sync and the launch configuration in use.
__device__ inline double2 shfl_xor_double2(double2 val, int laneMask,
                                           unsigned mask = 0xFFFFFFFF) {
  double2 exchanged;
  exchanged.x = __shfl_xor_sync(mask, val.x, laneMask);
  exchanged.y = __shfl_xor_sync(mask, val.y, laneMask);
  return exchanged;
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user