From e7e6c8fb53fee14666ef4a950c3ff457da72ae72 Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Mon, 27 Feb 2023 14:10:03 -0300 Subject: [PATCH] chore(cuda): Reduces shared memory consumption in the amortized PBS and improves loop unrolling. --- src/bootstrap_amortized.cuh | 28 ++++++++++++---------------- src/vertical_packing.cuh | 29 ++++++++++++++++------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/bootstrap_amortized.cuh b/src/bootstrap_amortized.cuh index 1d6872bf1..359ed3035 100644 --- a/src/bootstrap_amortized.cuh +++ b/src/bootstrap_amortized.cuh @@ -124,11 +124,12 @@ __global__ void device_bootstrap_amortized( // The polynomial multiplications happens at the block level // and each thread handles two or more coefficients int pos = threadIdx.x; - for (int j = 0; j < (glwe_dimension + 1) * params::opt / 2; j++) { - res_fft[pos].x = 0; - res_fft[pos].y = 0; - pos += params::degree / params::opt; - } + for (int i = 0; i < (glwe_dimension + 1); i++) + for (int j = 0; j < params::opt / 2; j++) { + res_fft[pos].x = 0; + res_fft[pos].y = 0; + pos += params::degree / params::opt; + } GadgetMatrix gadget(base_log, level_count, accumulator_rotated, glwe_dimension + 1); @@ -136,20 +137,17 @@ __global__ void device_bootstrap_amortized( // coefficients so as to multiply each decomposed level with the // corresponding part of the bootstrapping key for (int level = level_count - 1; level >= 0; level--) { - gadget.decompose_and_compress_next(accumulator_fft); for (int i = 0; i < (glwe_dimension + 1); i++) { - auto accumulator_fft_slice = accumulator_fft + i * params::degree / 2; + gadget.decompose_and_compress_next_polynomial(accumulator_fft, i); // Switch to the FFT space - NSMFFT_direct>(accumulator_fft_slice); - synchronize_threads_in_block(); + NSMFFT_direct>(accumulator_fft); // Get the bootstrapping key piece necessary for the multiplication // It is already in the Fourier domain auto bsk_slice = get_ith_mask_kth_block(bootstrapping_key, iteration, i, level, polynomial_size, glwe_dimension, level_count); - synchronize_threads_in_block(); // Perform the coefficient-wise product with the two pieces of // bootstrapping key @@ -157,7 +155,7 @@ __global__ void device_bootstrap_amortized( auto bsk_poly = bsk_slice + j * params::degree / 2; auto res_fft_poly = res_fft + j * params::degree / 2; polynomial_product_accumulate_in_fourier_domain( - res_fft_poly, accumulator_fft_slice, bsk_poly); + res_fft_poly, accumulator_fft, bsk_poly); } } synchronize_threads_in_block(); @@ -219,9 +217,8 @@ get_buffer_size_full_sm_bootstrap_amortized(uint32_t polynomial_size, uint32_t glwe_dimension) { return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator sizeof(Torus) * polynomial_size * - (glwe_dimension + 1) + // accumulator rotated - sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1) + // accumulator fft + (glwe_dimension + 1) + // accumulator rotated + sizeof(double2) * polynomial_size / 2 + // accumulator fft sizeof(double2) * polynomial_size / 2 * (glwe_dimension + 1); // calculate buffer fft } @@ -230,8 +227,7 @@ template __host__ __device__ int get_buffer_size_partial_sm_bootstrap_amortized(uint32_t polynomial_size, uint32_t glwe_dimension) { - return sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1); // calculate buffer fft + return sizeof(double2) * polynomial_size / 2; // calculate buffer fft } template diff --git a/src/vertical_packing.cuh b/src/vertical_packing.cuh index c2d002411..360e26637 100644 --- a/src/vertical_packing.cuh +++ b/src/vertical_packing.cuh @@ -75,11 +75,12 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in, // The polynomial multiplications happens at the block level // and each thread handles two or more coefficients int pos = threadIdx.x; - for (int j = 0; j < (glwe_dim + 1) * params::opt / 2; j++) { - res_fft[pos].x = 0; - res_fft[pos].y = 0; - pos += params::degree / params::opt; - } + for (int i = 0; i < (glwe_dim + 1); i++) + for (int j = 0; j < params::opt / 2; j++) { + res_fft[pos].x = 0; + res_fft[pos].y = 0; + pos += params::degree / params::opt; + } synchronize_threads_in_block(); GadgetMatrix gadget(base_log, level_count, glwe_sub, @@ -124,10 +125,11 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in, Torus *mb = &glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size]; int tid = threadIdx.x; - for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) { - mb[tid] = m0[tid]; - tid += params::degree / params::opt; - } + for (int i = 0; i < (glwe_dim + 1); i++) + for (int j = 0; j < params::opt; j++) { + mb[tid] = m0[tid]; + tid += params::degree / params::opt; + } for (int i = 0; i < (glwe_dim + 1); i++) { auto res_fft_slice = res_fft + i * params::degree / 2; @@ -425,10 +427,11 @@ __global__ void device_blind_rotation_and_sample_extraction( // Input LUT auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size]; int tid = threadIdx.x; - for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) { - accumulator_c0[tid] = mi[tid]; - tid += params::degree / params::opt; - } + for (int i = 0; i < (glwe_dim + 1); i++) + for (int j = 0; j < params::opt; j++) { + accumulator_c0[tid] = mi[tid]; + tid += params::degree / params::opt; + } int monomial_degree = 0; for (int i = mbr_size - 1; i >= 0; i--) {