mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(gpu): extra optimizations for 2_2 params kernels and bugs fixes
This commit is contained in:
@@ -143,7 +143,6 @@ template <typename T, class params, uint32_t base_log>
|
||||
__device__ void decompose_and_compress_level_2_2_params(double2 *result,
|
||||
T *state) {
|
||||
constexpr T mask_mod_b = (1ll << base_log) - 1ll;
|
||||
uint32_t tid = threadIdx.x;
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
auto input1 = state[i];
|
||||
auto input2 = state[i + params::opt / 2];
|
||||
@@ -158,20 +157,12 @@ __device__ void decompose_and_compress_level_2_2_params(double2 *result,
|
||||
carry_re >>= (base_log - 1);
|
||||
carry_im >>= (base_log - 1);
|
||||
|
||||
/* We don't need to update the state because we know we won't use it anymore
|
||||
* in 2_2 params. The skipped state updates would be:
|
||||
*   input1 += carry_re; input2 += carry_im;
|
||||
*/
|
||||
|
||||
res_re -= carry_re << base_log;
|
||||
res_im -= carry_im << base_log;
|
||||
|
||||
typecast_torus_to_double(res_re, result[tid].x);
|
||||
typecast_torus_to_double(res_im, result[tid].y);
|
||||
|
||||
tid += params::degree / params::opt;
|
||||
typecast_torus_to_double(res_re, result[i].x);
|
||||
typecast_torus_to_double(res_im, result[i].y);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -148,9 +148,11 @@ template <class params> __device__ void NSMFFT_direct(double2 *A) {
|
||||
* negacyclic fft optimized for 2_2 params
|
||||
it uses the twiddles from shared memory for extra performance
|
||||
this is possible because we know that 2_2 params leave memory available
|
||||
the fft is returned in registers to avoid extra synchronizations
|
||||
*/
|
||||
template <class params>
|
||||
__device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
__device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *fft_out,
|
||||
double2 *shared_twiddles) {
|
||||
|
||||
/* We don't make bit reverse here, since twiddles are already reversed
|
||||
* Each thread is always in charge of "opt/2" pairs of coefficients,
|
||||
@@ -159,7 +161,6 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
* full loop, which should increase performance
|
||||
*/
|
||||
|
||||
__syncthreads();
|
||||
constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
|
||||
constexpr Index LOG2_DEGREE = params::log2_degree;
|
||||
constexpr Index HALF_DEGREE = params::degree >> 1;
|
||||
@@ -168,13 +169,10 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
Index tid = threadIdx.x;
|
||||
double2 u[BUTTERFLY_DEPTH], v[BUTTERFLY_DEPTH], w;
|
||||
|
||||
// load into registers
|
||||
#pragma unroll
|
||||
// switch register order
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
u[i] = A[tid];
|
||||
v[i] = A[tid + HALF_DEGREE];
|
||||
|
||||
tid += STRIDE;
|
||||
u[i] = fft_out[i];
|
||||
v[i] = fft_out[i + params::opt / 2];
|
||||
}
|
||||
|
||||
// level 1
|
||||
@@ -231,21 +229,13 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
|
||||
tid = threadIdx.x;
|
||||
double2 reg_A[BUTTERFLY_DEPTH];
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
@@ -259,16 +249,12 @@ __device__ void NSMFFT_direct_2_2_params(double2 *A, double2 *shared_twiddles) {
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// store registers in SM
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
// Return result in registers, no need to synchronize here
|
||||
// (it is only needed if we use the same shared memory afterwards)
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
A[tid * 2] = u[i];
|
||||
A[tid * 2 + 1] = v[i];
|
||||
tid = tid + STRIDE;
|
||||
fft_out[i] = u[i];
|
||||
fft_out[i + params::opt / 2] = v[i];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -406,9 +392,11 @@ template <class params> __device__ void NSMFFT_inverse(double2 *A) {
|
||||
* negacyclic inverse fft optimized for 2_2 params
|
||||
* it uses the twiddles from shared memory for extra performance
|
||||
* this is possible because we know that 2_2 params leave memory available
|
||||
* the input comes from registers to avoid some synchronizations and shared mem
|
||||
* usage
|
||||
*/
|
||||
template <class params>
|
||||
__device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
__device__ void NSMFFT_inverse_2_2_params(double2 *A, double2 *buffer_regs,
|
||||
double2 *shared_twiddles) {
|
||||
|
||||
/* We don't make bit reverse here, since twiddles are already reversed
|
||||
@@ -418,7 +406,6 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
* full loop, which should increase performance
|
||||
*/
|
||||
|
||||
__syncthreads();
|
||||
constexpr Index BUTTERFLY_DEPTH = params::opt >> 1;
|
||||
constexpr Index LOG2_DEGREE = params::log2_degree;
|
||||
constexpr Index DEGREE = params::degree;
|
||||
@@ -429,15 +416,12 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
double2 u[BUTTERFLY_DEPTH], v[BUTTERFLY_DEPTH], w;
|
||||
|
||||
// load into registers and divide by compressed polynomial size
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
u[i] = A[2 * tid];
|
||||
v[i] = A[2 * tid + 1];
|
||||
u[i] = buffer_regs[i];
|
||||
v[i] = buffer_regs[i + params::opt / 2];
|
||||
|
||||
u[i] /= DEGREE;
|
||||
v[i] /= DEGREE;
|
||||
|
||||
tid += STRIDE;
|
||||
}
|
||||
|
||||
Index twiddle_shift = DEGREE;
|
||||
@@ -449,7 +433,6 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
|
||||
// at this point registers are ready for the butterfly
|
||||
tid = threadIdx.x;
|
||||
__syncwarp();
|
||||
double2 reg_A[BUTTERFLY_DEPTH];
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
@@ -457,11 +440,6 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
u[i] += v[i];
|
||||
v[i] = w * conjugate(shared_twiddles[tid / lane_mask + twiddle_shift]);
|
||||
|
||||
// keep one of the registers for the next iteration and store the other in sm
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncwarp();
|
||||
@@ -472,6 +450,7 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
Index rank = tid & thread_mask;
|
||||
bool u_stays_in_register = rank < lane_mask;
|
||||
reg_A[i] = (u_stays_in_register) ? v[i] : u[i];
|
||||
w = shfl_xor_double2(reg_A[i], 1 << (l - 1), 0xFFFFFFFF);
|
||||
u[i] = (u_stays_in_register) ? u[i] : w;
|
||||
v[i] = (u_stays_in_register) ? w : v[i];
|
||||
@@ -518,23 +497,15 @@ __device__ void NSMFFT_inverse_2_2_params(double2 *A,
|
||||
}
|
||||
}
|
||||
|
||||
// last iteration
|
||||
// last iteration
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; ++i) {
|
||||
w = (u[i] - v[i]);
|
||||
u[i] = u[i] + v[i];
|
||||
v[i] = w * (double2){0.707106781186547461715008466854,
|
||||
-0.707106781186547461715008466854};
|
||||
buffer_regs[i] = u[i] + v[i];
|
||||
buffer_regs[i + params::opt / 2] =
|
||||
w * (double2){0.707106781186547461715008466854,
|
||||
-0.707106781186547461715008466854};
|
||||
}
|
||||
__syncthreads();
|
||||
// store registers in SM
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (Index i = 0; i < BUTTERFLY_DEPTH; i++) {
|
||||
A[tid] = u[i];
|
||||
A[tid + HALF_DEGREE] = v[i];
|
||||
tid = tid + STRIDE;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -101,19 +101,19 @@ mul_ggsw_glwe_in_fourier_domain(double2 *fft, double2 *join_buffer,
|
||||
* - Thread blocks at dimension z relates to the decomposition level.
|
||||
* - Thread blocks at dimension y relates to the glwe dimension.
|
||||
* - polynomial_size / params::opt threads are available per block
|
||||
* - local fft is read from registers
|
||||
* To avoid a cluster synchronization the accumulator output is different than
|
||||
* the input, and on the next iteration they are swapped (ping-pong buffer).
|
||||
*/
|
||||
template <typename G, class params, uint32_t polynomial_size,
|
||||
uint32_t glwe_dimension, uint32_t level_count>
|
||||
__device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
|
||||
double2 *fft, double2 *accumulator_out,
|
||||
double2 *fft, double2 *fft_regs, double2 *buffer_regs,
|
||||
const double2 *__restrict__ bootstrapping_key, int iteration, G &group,
|
||||
int this_block_rank) {
|
||||
// Continues multiplying fft by every polynomial in that particular bsk level
|
||||
// Each y-block accumulates in a different polynomial at each iteration
|
||||
// We accumulate in registers to free shared memory
|
||||
double2 buffer_regs[params::opt / 2];
|
||||
// In 2_2 params we only have one level
|
||||
constexpr uint32_t level_id = 0;
|
||||
// The first product doesn't need using dsm
|
||||
@@ -124,7 +124,7 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
|
||||
auto bsk_poly = bsk_slice + blockIdx.y * polynomial_size / 2;
|
||||
polynomial_product_accumulate_in_fourier_domain_2_2_params<params, double2,
|
||||
true>(
|
||||
buffer_regs, fft, bsk_poly);
|
||||
buffer_regs, fft_regs, bsk_poly);
|
||||
|
||||
// Synchronize to ensure all blocks have written its fft result
|
||||
group.sync();
|
||||
@@ -142,14 +142,8 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
|
||||
buffer_regs, fft_slice, bsk_poly);
|
||||
|
||||
// We don't need to synchronize here, because we are going to use a buffer
|
||||
// different than the input. In 2_2 params, level_count=1 so we can just copy
|
||||
// the result from the registers into shared without needing to accumulate
|
||||
int tid = threadIdx.x;
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
accumulator_out[tid] = buffer_regs[i];
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
__syncthreads();
|
||||
// different than the input. In 2_2 params, level_count=1 so we can just return
|
||||
// the buffer in registers to avoid synchronizations and shared memory usage
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -415,15 +415,14 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64(
|
||||
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
|
||||
level_count, cuda_get_max_shared_memory(gpu_index));
|
||||
|
||||
// if (supports_tbc &&
|
||||
// !(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
|
||||
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count, input_lwe_ciphertext_count,
|
||||
allocate_gpu_memory);
|
||||
// else
|
||||
if (supports_tbc)
|
||||
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
input_lwe_ciphertext_count, allocate_gpu_memory);
|
||||
else
|
||||
#endif
|
||||
if (supports_cg)
|
||||
if (supports_cg)
|
||||
return scratch_cuda_cg_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
@@ -492,6 +491,17 @@ uint32_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
|
||||
int log2_max_num_pbs = log2_int(max_num_pbs);
|
||||
if (log2_max_num_pbs > 13)
|
||||
ith_divisor = log2_max_num_pbs - 11;
|
||||
#else
|
||||
// When having few samples we are interested in using a larger chunksize so
|
||||
// the keybundle can saturate the GPU. To obtain homogeneous waves we use half
|
||||
// of the sms as the chunksize, by doing so we always get a multiple of the
|
||||
// number of sms, removing the tail effect. We don't divide by 4 because
|
||||
// some flavors of H100 might not have a number of sms divisible by 4. This is
|
||||
// applied only to a small number of samples (<= 8) because it can have a negative
|
||||
// effect of over saturation.
|
||||
if (max_num_pbs <= 8) {
|
||||
return num_sms / 2;
|
||||
}
|
||||
#endif
|
||||
|
||||
for (int i = sqrt(x); i >= 1; i--) {
|
||||
|
||||
@@ -132,6 +132,20 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
|
||||
}
|
||||
}
|
||||
|
||||
// Calculates the keybundles for 2_2 params
|
||||
// Lwe Dimension = 920
|
||||
// Polynomial Size = 2048
|
||||
// Grouping factor = 4
|
||||
// Glwe dimension = 1
|
||||
// PBS level = 1
|
||||
// In this initial version everything is hardcoded as constexpr, we
|
||||
// will wrap it up in a nicer/cleaner version in the future.
|
||||
// Additionally, we initialize an int8_t vector with coefficients used in the
|
||||
// monomial multiplication. The size of this vector is 3x2048 and the
|
||||
// coefficients are: [0 .. 2047] = -1 [2048 .. 4095] = 1 [4096 .. 6143] = -11
|
||||
// Then we can just calculate the offset needed to apply these coefficients, and
|
||||
// the operation transforms into a pointwise vector multiplication, avoiding to
|
||||
// perform extra instructions other than MADD
|
||||
template <typename Torus, class params, sharedMemDegree SMD>
|
||||
__global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
@@ -148,11 +162,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
|
||||
extern __shared__ int8_t sharedmem[];
|
||||
int8_t *selected_memory;
|
||||
selected_memory = sharedmem;
|
||||
// Some int8_t coefficients of values {-1,1,1} are precalculated to accelerate
|
||||
// the monomial multiplication. There is no need of synchronization since the
|
||||
// sync is done after monomial calculation. The precalculated coefficients are
|
||||
// stored after the memory reserved for the monomials
|
||||
// uint32_t[1<<grouping_factor]
|
||||
|
||||
int8_t *precalc_coefs =
|
||||
selected_memory + (sizeof(uint32_t) * (1 << grouping_factor));
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
@@ -164,6 +174,12 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
|
||||
}
|
||||
|
||||
double2 *shared_fft = (double2 *)(precalc_coefs + polynomial_size * 3);
|
||||
double2 *shared_twiddles = shared_fft + (polynomial_size / 2);
|
||||
for (int k = 0; k < params::opt / 2; k++) {
|
||||
shared_twiddles[threadIdx.x + k * (params::degree / params::opt)] =
|
||||
negtwiddles[threadIdx.x + k * (params::degree / params::opt)];
|
||||
}
|
||||
|
||||
// Ids
|
||||
constexpr uint32_t level_id = 0;
|
||||
uint32_t glwe_id = blockIdx.y / (glwe_dimension + 1);
|
||||
@@ -228,31 +244,29 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle_2_2_params(
|
||||
}
|
||||
|
||||
// Move from local memory back to shared memory but as complex
|
||||
int tid = threadIdx.x;
|
||||
double2 fft_regs[params::opt / 2];
|
||||
double2 *fft = shared_fft;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
fft[tid] =
|
||||
fft_regs[i] =
|
||||
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
|
||||
(double)std::numeric_limits<Torus>::max(),
|
||||
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
|
||||
(double)std::numeric_limits<Torus>::max());
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
NSMFFT_direct<HalfDegree<params>>(fft);
|
||||
NSMFFT_direct_2_2_params<HalfDegree<params>>(fft, fft_regs,
|
||||
shared_twiddles);
|
||||
|
||||
// lwe iteration
|
||||
auto keybundle_out = get_ith_mask_kth_block(
|
||||
keybundle, blockIdx.x % lwe_chunk_size, glwe_id, level_id,
|
||||
polynomial_size, glwe_dimension, level_count);
|
||||
// auto keybundle_out = get_ith_mask_kth_block_2_2_params<Torus,
|
||||
// polynomial_size,glwe_dimension,level_count,level_id>(keybundle,
|
||||
// blockIdx.x % lwe_chunk_size, glwe_id);
|
||||
auto keybundle_poly = keybundle_out + poly_id * params::degree / 2;
|
||||
|
||||
copy_polynomial<double2, params::opt / 2, params::degree / params::opt>(
|
||||
fft, keybundle_poly);
|
||||
auto keybundle_poly = keybundle_out + poly_id * params::degree / 2;
|
||||
copy_polynomial_from_regs<double2, params::opt / 2,
|
||||
params::degree / params::opt>(fft_regs,
|
||||
keybundle_poly);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -679,8 +693,13 @@ __host__ void execute_compute_keybundle(
|
||||
level_count, lwe_offset, chunk_size, keybundle_size_per_input,
|
||||
d_mem, full_sm_keybundle);
|
||||
} else {
|
||||
if (polynomial_size == 2048 && grouping_factor == 4 && level_count == 1 &&
|
||||
glwe_dimension == 1 && lwe_dimension == 920) {
|
||||
bool supports_tbc =
|
||||
has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
|
||||
num_samples, glwe_dimension, polynomial_size, level_count,
|
||||
cuda_get_max_shared_memory(gpu_index));
|
||||
|
||||
if (supports_tbc && polynomial_size == 2048 && grouping_factor == 4 &&
|
||||
level_count == 1 && glwe_dimension == 1 && lwe_dimension == 920) {
|
||||
dim3 thds_new_keybundle(512, 1, 1);
|
||||
check_cuda_error(cudaFuncSetAttribute(
|
||||
device_multi_bit_programmable_bootstrap_keybundle_2_2_params<
|
||||
|
||||
@@ -181,18 +181,34 @@ __global__ void __launch_bounds__(params::degree / params::opt)
|
||||
}
|
||||
}
|
||||
|
||||
// Specialized version for the multi-bit bootstrap using 2_2 params:
|
||||
// Polynomial size = 2048
|
||||
// PBS level = 1
|
||||
// Grouping factor = 4
|
||||
// PBS base = 22
|
||||
// Glwe dimension = 1
|
||||
// At the moment everything is hardcoded as constexpr, but later
|
||||
// we will generate a cleaner/nicer way to handle it.
|
||||
// Main optimizations:
|
||||
//- Leverage shared memory to reduce one cluster synchronization. A
|
||||
// ping pong buffer is used for that, so everything is synchronized
|
||||
// automatically after 2 iterations
|
||||
//- Move everything to registers to avoid shared memory synchronizations
|
||||
//- Use a register based fft that uses the minimal synchronizations
|
||||
//- Register based fourier domain multiplication. Transfer fft's between blocks
|
||||
// instead of accumulator.
|
||||
template <typename Torus, class params, sharedMemDegree SMD>
|
||||
__global__ void
|
||||
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
|
||||
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
const Torus *__restrict__ lut_vector,
|
||||
const Torus *__restrict__ lut_vector_indexes,
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
const Torus *__restrict__ lwe_input_indexes,
|
||||
const double2 *__restrict__ keybundle_array, double2 *join_buffer,
|
||||
Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset,
|
||||
uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
|
||||
uint32_t num_many_lut, uint32_t lut_stride) {
|
||||
__global__ void __launch_bounds__(params::degree / params::opt)
|
||||
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
|
||||
Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
|
||||
const Torus *__restrict__ lut_vector,
|
||||
const Torus *__restrict__ lut_vector_indexes,
|
||||
const Torus *__restrict__ lwe_array_in,
|
||||
const Torus *__restrict__ lwe_input_indexes,
|
||||
const double2 *__restrict__ keybundle_array, double2 *join_buffer,
|
||||
Torus *global_accumulator, uint32_t lwe_dimension, uint32_t lwe_offset,
|
||||
uint32_t lwe_chunk_size, uint32_t keybundle_size_per_input,
|
||||
uint32_t num_many_lut, uint32_t lut_stride) {
|
||||
|
||||
constexpr uint32_t level_count = 1;
|
||||
constexpr uint32_t grouping_factor = 4;
|
||||
@@ -215,9 +231,10 @@ device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
|
||||
constexpr uint32_t num_buffers_ping_pong = 2;
|
||||
selected_memory += sizeof(Torus) * polynomial_size * num_buffers_ping_pong;
|
||||
|
||||
double2 *accumulator_aux = (double2 *)sharedmem;
|
||||
double2 *accumulator_fft = accumulator_aux + (polynomial_size / 2);
|
||||
double2 *shared_twiddles = accumulator_fft + (polynomial_size / 2);
|
||||
double2 *accumulator_ping = (double2 *)sharedmem;
|
||||
double2 *accumulator_pong = accumulator_ping + (polynomial_size / 2);
|
||||
double2 *shared_twiddles = accumulator_pong + (polynomial_size / 2);
|
||||
double2 *shared_fft = shared_twiddles + (polynomial_size / 2);
|
||||
// accumulator rotated shares the same memory space as the twiddles.
|
||||
// it is only used during the sample extract so it is safe to use it
|
||||
Torus *accumulator_rotated = (Torus *)selected_memory;
|
||||
@@ -271,30 +288,37 @@ device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params(
|
||||
reg_acc_rotated);
|
||||
|
||||
// This is the ping pong buffer logic to avoid a cluster synchronization
|
||||
auto accumulator_in = i % 2 ? accumulator_fft : accumulator_aux;
|
||||
auto accumulator_out = i % 2 ? accumulator_aux : accumulator_fft;
|
||||
auto accumulator_fft = i % 2 ? accumulator_ping : accumulator_pong;
|
||||
|
||||
double2 fft_out_regs[params::opt / 2];
|
||||
// Decompose the accumulator. Each block gets one level of the
|
||||
// decomposition, for the mask and the body (so block 0 will have the
|
||||
// accumulator decomposed at level 0, 1 at 1, etc.)
|
||||
decompose_and_compress_level_2_2_params<Torus, params, base_log>(
|
||||
accumulator_in, reg_acc_rotated);
|
||||
fft_out_regs, reg_acc_rotated);
|
||||
|
||||
NSMFFT_direct_2_2_params<HalfDegree<params>>(accumulator_in,
|
||||
NSMFFT_direct_2_2_params<HalfDegree<params>>(shared_fft, fft_out_regs,
|
||||
shared_twiddles);
|
||||
__syncthreads();
|
||||
// we move registers into shared memory to use dsm
|
||||
int tid = threadIdx.x;
|
||||
for (Index k = 0; k < params::opt / 4; k++) {
|
||||
accumulator_fft[tid] = fft_out_regs[k];
|
||||
accumulator_fft[tid + params::degree / 4] =
|
||||
fft_out_regs[k + params::opt / 4];
|
||||
tid = tid + params::degree / params::opt;
|
||||
}
|
||||
|
||||
double2 buffer_regs[params::opt / 2];
|
||||
// Perform G^-1(ACC) * GGSW -> GLWE
|
||||
mul_ggsw_glwe_in_fourier_domain_2_2_params<
|
||||
cluster_group, params, polynomial_size, glwe_dimension, level_count>(
|
||||
accumulator_in, accumulator_out, keybundle, i, cluster,
|
||||
accumulator_fft, fft_out_regs, buffer_regs, keybundle, i, cluster,
|
||||
this_block_rank);
|
||||
|
||||
NSMFFT_inverse_2_2_params<HalfDegree<params>>(accumulator_out,
|
||||
NSMFFT_inverse_2_2_params<HalfDegree<params>>(shared_fft, buffer_regs,
|
||||
shared_twiddles);
|
||||
__syncthreads();
|
||||
|
||||
add_to_torus_2_2_params<Torus, params>(accumulator_out, reg_acc_rotated);
|
||||
add_to_torus_2_2_params<Torus, params>(buffer_regs, reg_acc_rotated);
|
||||
}
|
||||
|
||||
if (lwe_offset + lwe_chunk_size >= (lwe_dimension / grouping_factor)) {
|
||||
@@ -454,7 +478,7 @@ __host__ uint64_t scratch_tbc_multi_bit_programmable_bootstrap(
|
||||
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
|
||||
Torus, params, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
full_sm_tbc_accumulate + minimum_sm_tbc_accumulate));
|
||||
full_sm_tbc_accumulate + 2 * minimum_sm_tbc_accumulate));
|
||||
check_cuda_error(cudaFuncSetAttribute(
|
||||
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
|
||||
Torus, params, FULLSM>,
|
||||
@@ -578,10 +602,13 @@ __host__ void execute_tbc_external_product_loop(
|
||||
config.dynamicSmemBytes = full_dm + minimum_dm;
|
||||
if (polynomial_size == 2048 && grouping_factor == 4 && level_count == 1 &&
|
||||
glwe_dimension == 1 && base_log == 22) {
|
||||
|
||||
config.dynamicSmemBytes = full_dm + 2 * minimum_dm;
|
||||
check_cuda_error(cudaFuncSetAttribute(
|
||||
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
|
||||
Torus, params, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, full_dm + minimum_dm));
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
full_dm + 2 * minimum_dm));
|
||||
check_cuda_error(cudaFuncSetAttribute(
|
||||
device_multi_bit_programmable_bootstrap_tbc_accumulate_2_2_params<
|
||||
Torus, params, FULLSM>,
|
||||
|
||||
@@ -237,8 +237,8 @@ __device__ void add_to_torus_2_2_params(double2 *m_values, Torus *result) {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
double double_real = m_values[tid].x;
|
||||
double double_imag = m_values[tid].y;
|
||||
double double_real = m_values[i].x;
|
||||
double double_imag = m_values[i].y;
|
||||
|
||||
Torus torus_real = 0;
|
||||
typecast_double_round_to_torus<Torus>(double_real, torus_real);
|
||||
|
||||
@@ -64,6 +64,8 @@ __device__ void polynomial_product_accumulate_in_fourier_domain(
|
||||
// Computes result += first * second
|
||||
// If init_accumulator is set, assumes that result was not initialized and does
|
||||
// that with the outcome of first * second
|
||||
// The result is always in registers and, if init_accumulator is true,
|
||||
// the first is also in registers. This is tuned for 2_2 params.
|
||||
template <class params, typename T, bool init_accumulator>
|
||||
__device__ void polynomial_product_accumulate_in_fourier_domain_2_2_params(
|
||||
T *__restrict__ result, T *__restrict__ first,
|
||||
@@ -71,12 +73,12 @@ __device__ void polynomial_product_accumulate_in_fourier_domain_2_2_params(
|
||||
int tid = threadIdx.x;
|
||||
if constexpr (init_accumulator) {
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
result[i] = first[tid] * second[tid];
|
||||
tid += params::degree / params::opt;
|
||||
result[i] = first[i] * __ldg(&second[tid]);
|
||||
tid += (params::degree / params::opt);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < params::opt / 2; i++) {
|
||||
result[i] += first[tid] * second[tid];
|
||||
result[i] += first[tid] * __ldg(&second[tid]);
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,9 +76,9 @@ __device__ inline double2 operator*(double a, double2 b) {
|
||||
|
||||
__device__ inline double2 shfl_xor_double2(double2 val, int laneMask,
|
||||
unsigned mask = 0xFFFFFFFF) {
|
||||
double lo = __shfl_xor_sync(mask, val.x, laneMask);
|
||||
double hi = __shfl_xor_sync(mask, val.y, laneMask);
|
||||
double re = __shfl_xor_sync(mask, val.x, laneMask);
|
||||
double im = __shfl_xor_sync(mask, val.y, laneMask);
|
||||
|
||||
return make_double2(lo, hi);
|
||||
return make_double2(re, im);
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user