chore(cuda): Reduces shared memory consumption in the amortized PBS and improves loop unrolling.

This commit is contained in:
Pedro Alves
2023-02-27 14:10:03 -03:00
committed by Pedro Alves
parent 400786f3f9
commit e7e6c8fb53
2 changed files with 28 additions and 29 deletions

View File

@@ -124,11 +124,12 @@ __global__ void device_bootstrap_amortized(
// The polynomial multiplications happens at the block level
// and each thread handles two or more coefficients
int pos = threadIdx.x;
for (int j = 0; j < (glwe_dimension + 1) * params::opt / 2; j++) {
res_fft[pos].x = 0;
res_fft[pos].y = 0;
pos += params::degree / params::opt;
}
for (int i = 0; i < (glwe_dimension + 1); i++)
for (int j = 0; j < params::opt / 2; j++) {
res_fft[pos].x = 0;
res_fft[pos].y = 0;
pos += params::degree / params::opt;
}
GadgetMatrix<Torus, params> gadget(base_log, level_count,
accumulator_rotated, glwe_dimension + 1);
@@ -136,20 +137,17 @@ __global__ void device_bootstrap_amortized(
// coefficients so as to multiply each decomposed level with the
// corresponding part of the bootstrapping key
for (int level = level_count - 1; level >= 0; level--) {
gadget.decompose_and_compress_next(accumulator_fft);
for (int i = 0; i < (glwe_dimension + 1); i++) {
auto accumulator_fft_slice = accumulator_fft + i * params::degree / 2;
gadget.decompose_and_compress_next_polynomial(accumulator_fft, i);
// Switch to the FFT space
NSMFFT_direct<HalfDegree<params>>(accumulator_fft_slice);
synchronize_threads_in_block();
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
// Get the bootstrapping key piece necessary for the multiplication
// It is already in the Fourier domain
auto bsk_slice = get_ith_mask_kth_block(bootstrapping_key, iteration, i,
level, polynomial_size,
glwe_dimension, level_count);
synchronize_threads_in_block();
// Perform the coefficient-wise product with the two pieces of
// bootstrapping key
@@ -157,7 +155,7 @@ __global__ void device_bootstrap_amortized(
auto bsk_poly = bsk_slice + j * params::degree / 2;
auto res_fft_poly = res_fft + j * params::degree / 2;
polynomial_product_accumulate_in_fourier_domain<params, double2>(
res_fft_poly, accumulator_fft_slice, bsk_poly);
res_fft_poly, accumulator_fft, bsk_poly);
}
}
synchronize_threads_in_block();
@@ -219,9 +217,8 @@ get_buffer_size_full_sm_bootstrap_amortized(uint32_t polynomial_size,
uint32_t glwe_dimension) {
return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator
sizeof(Torus) * polynomial_size *
(glwe_dimension + 1) + // accumulator rotated
sizeof(double2) * polynomial_size / 2 *
(glwe_dimension + 1) + // accumulator fft
(glwe_dimension + 1) + // accumulator rotated
sizeof(double2) * polynomial_size / 2 + // accumulator fft
sizeof(double2) * polynomial_size / 2 *
(glwe_dimension + 1); // calculate buffer fft
}
@@ -230,8 +227,7 @@ template <typename Torus>
__host__ __device__ int
get_buffer_size_partial_sm_bootstrap_amortized(uint32_t polynomial_size,
uint32_t glwe_dimension) {
return sizeof(double2) * polynomial_size / 2 *
(glwe_dimension + 1); // calculate buffer fft
return sizeof(double2) * polynomial_size / 2; // calculate buffer fft
}
template <typename Torus>

View File

@@ -75,11 +75,12 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
// The polynomial multiplications happens at the block level
// and each thread handles two or more coefficients
int pos = threadIdx.x;
for (int j = 0; j < (glwe_dim + 1) * params::opt / 2; j++) {
res_fft[pos].x = 0;
res_fft[pos].y = 0;
pos += params::degree / params::opt;
}
for (int i = 0; i < (glwe_dim + 1); i++)
for (int j = 0; j < params::opt / 2; j++) {
res_fft[pos].x = 0;
res_fft[pos].y = 0;
pos += params::degree / params::opt;
}
synchronize_threads_in_block();
GadgetMatrix<Torus, params> gadget(base_log, level_count, glwe_sub,
@@ -124,10 +125,11 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
Torus *mb = &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
int tid = threadIdx.x;
for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
mb[tid] = m0[tid];
tid += params::degree / params::opt;
}
for (int i = 0; i < (glwe_dim + 1); i++)
for (int j = 0; j < params::opt; j++) {
mb[tid] = m0[tid];
tid += params::degree / params::opt;
}
for (int i = 0; i < (glwe_dim + 1); i++) {
auto res_fft_slice = res_fft + i * params::degree / 2;
@@ -425,10 +427,11 @@ __global__ void device_blind_rotation_and_sample_extraction(
// Input LUT
auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size];
int tid = threadIdx.x;
for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
accumulator_c0[tid] = mi[tid];
tid += params::degree / params::opt;
}
for (int i = 0; i < (glwe_dim + 1); i++)
for (int j = 0; j < params::opt; j++) {
accumulator_c0[tid] = mi[tid];
tid += params::degree / params::opt;
}
int monomial_degree = 0;
for (int i = mbr_size - 1; i >= 0; i--) {