diff --git a/src/bit_extraction.cuh b/src/bit_extraction.cuh index 10f1f86f5..fa0de9659 100644 --- a/src/bit_extraction.cuh +++ b/src/bit_extraction.cuh @@ -135,15 +135,20 @@ get_buffer_size_extract_bits(uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t number_of_inputs) { - return sizeof(Torus) * number_of_inputs // lut_vector_indexes - + ((glwe_dimension + 1) * polynomial_size) * sizeof(Torus) // lut_pbs - + (glwe_dimension * polynomial_size + 1) * - sizeof(Torus) // lwe_array_in_buffer - + (glwe_dimension * polynomial_size + 1) * - sizeof(Torus) // lwe_array_in_shifted_buffer - + (lwe_dimension + 1) * sizeof(Torus) // lwe_array_out_ks_buffer - + (glwe_dimension * polynomial_size + 1) * - sizeof(Torus); // lwe_array_out_pbs_buffer + int buffer_size = + sizeof(Torus) * number_of_inputs // lut_vector_indexes + + ((glwe_dimension + 1) * polynomial_size) * sizeof(Torus) // lut_pbs + + (glwe_dimension * polynomial_size + 1) * + sizeof(Torus) // lwe_array_in_buffer + + (glwe_dimension * polynomial_size + 1) * + sizeof(Torus) // lwe_array_in_shifted_buffer + + (lwe_dimension + 1) * sizeof(Torus) // lwe_array_out_ks_buffer + + (glwe_dimension * polynomial_size + 1) * + sizeof(Torus); // lwe_array_out_pbs_buffer + // Pad up to the next multiple of sizeof(double2); adding the bare remainder + // only lands on a boundary when that remainder is 0 or sizeof(double2) / 2. + const size_t align = sizeof(double2); + return buffer_size + (align - buffer_size % align) % align; } template diff --git a/src/crypto/gadget.cuh b/src/crypto/gadget.cuh index 51fe9d2eb..d22ed6bd8 100644 --- a/src/crypto/gadget.cuh +++ b/src/crypto/gadget.cuh @@ -43,35 +43,45 @@ public: synchronize_threads_in_block(); } + // Decomposes all polynomials at once __device__ void decompose_and_compress_next(double2 *result) { - current_level -= 1; for (int j = 0; j < num_poly; j++) { - int tid = threadIdx.x; - auto state_slice = state + j * params::degree; auto result_slice = result + j * params::degree / 2; - for (int i = 0; i < params::opt / 2; i++) { - T res_re = state_slice[tid] & mask_mod_b; - T res_im = state_slice[tid + params::degree / 2] & mask_mod_b; - state_slice[tid] 
>>= base_log; - state_slice[tid + params::degree / 2] >>= base_log; - T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re; - T carry_im = - ((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im; - carry_re >>= (base_log - 1); - carry_im >>= (base_log - 1); - state_slice[tid] += carry_re; - state_slice[tid + params::degree / 2] += carry_im; - res_re -= carry_re << base_log; - res_im -= carry_im << base_log; + decompose_and_compress_next_polynomial(result_slice, j); + } + } - result_slice[tid].x = (int32_t)res_re; - result_slice[tid].y = (int32_t)res_im; + // Decomposes a single polynomial + __device__ void decompose_and_compress_next_polynomial(double2 *result, + int j) { + if (j == 0) + current_level -= 1; - tid += params::degree / params::opt; - } + int tid = threadIdx.x; + auto state_slice = state + j * params::degree; + for (int i = 0; i < params::opt / 2; i++) { + T res_re = state_slice[tid] & mask_mod_b; + T res_im = state_slice[tid + params::degree / 2] & mask_mod_b; + state_slice[tid] >>= base_log; + state_slice[tid + params::degree / 2] >>= base_log; + T carry_re = ((res_re - 1ll) | state_slice[tid]) & res_re; + T carry_im = + ((res_im - 1ll) | state_slice[tid + params::degree / 2]) & res_im; + carry_re >>= (base_log - 1); + carry_im >>= (base_log - 1); + state_slice[tid] += carry_re; + state_slice[tid + params::degree / 2] += carry_im; + res_re -= carry_re << base_log; + res_im -= carry_im << base_log; + + result[tid].x = (int32_t)res_re; + result[tid].y = (int32_t)res_im; + + tid += params::degree / params::opt; } synchronize_threads_in_block(); } + __device__ void decompose_and_compress_level(double2 *result, int level) { for (int i = 0; i < level_count - level; i++) decompose_and_compress_next(result); diff --git a/src/vertical_packing.cuh b/src/vertical_packing.cuh index 9f55d1b47..c2d002411 100644 --- a/src/vertical_packing.cuh +++ b/src/vertical_packing.cuh @@ -84,31 +84,29 @@ cmux(Torus *glwe_array_out, Torus *glwe_array_in, 
double2 *ggsw_in, synchronize_threads_in_block(); GadgetMatrix gadget(base_log, level_count, glwe_sub, glwe_dim + 1); + // Subtract each glwe operand, decompose the resulting // polynomial coefficients to multiply each decomposed level // with the corresponding part of the LUT for (int level = level_count - 1; level >= 0; level--) { // Decomposition - gadget.decompose_and_compress_next(glwe_fft); - synchronize_threads_in_block(); for (int i = 0; i < (glwe_dim + 1); i++) { - auto glwe_fft_slice = glwe_fft + i * params::degree / 2; + gadget.decompose_and_compress_next_polynomial(glwe_fft, i); // First, perform the polynomial multiplication - NSMFFT_direct>(glwe_fft_slice); + NSMFFT_direct>(glwe_fft); // External product and accumulate // Get the piece necessary for the multiplication auto bsk_slice = get_ith_mask_kth_block( ggsw_in, ggsw_idx, i, level, polynomial_size, glwe_dim, level_count); - synchronize_threads_in_block(); // Perform the coefficient-wise product for (int j = 0; j < (glwe_dim + 1); j++) { auto bsk_poly = bsk_slice + j * params::degree / 2; auto res_fft_poly = res_fft + j * params::degree / 2; polynomial_product_accumulate_in_fourier_domain( - res_fft_poly, glwe_fft_slice, bsk_poly); + res_fft_poly, glwe_fft, bsk_poly); } } synchronize_threads_in_block(); @@ -215,9 +213,8 @@ get_memory_needed_per_block_cmux_tree(uint32_t glwe_dimension, uint32_t polynomial_size) { return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // glwe_sub sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1) + // res_fft - sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1); // glwe_fft + (glwe_dimension + 1) + // res_fft + sizeof(double2) * polynomial_size / 2; // glwe_fft } template @@ -538,8 +535,6 @@ __host__ void host_blind_rotate_and_sample_extraction( uint32_t level_count, uint32_t max_shared_memory) { cudaSetDevice(gpu_index); - assert(glwe_dimension == - 1); // For larger k we will need to adjust the mask size auto stream = 
static_cast(v_stream); int memory_needed_per_block = diff --git a/src/wop_bootstrap.cuh b/src/wop_bootstrap.cuh index 4a3e17cba..e073327ce 100644 --- a/src/wop_bootstrap.cuh +++ b/src/wop_bootstrap.cuh @@ -305,9 +305,9 @@ __host__ void host_wop_pbs( host_extract_bits( v_stream, gpu_index, (Torus *)lwe_array_out_bit_extract, lwe_array_in, bit_extract_buffer, ksk, fourier_bsk, number_of_bits_to_extract, - delta_log, polynomial_size, lwe_dimension, glwe_dimension, - polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk, - level_count_ksk, number_of_inputs, max_shared_memory); + delta_log, glwe_dimension * polynomial_size, lwe_dimension, + glwe_dimension, polynomial_size, base_log_bsk, level_count_bsk, + base_log_ksk, level_count_ksk, number_of_inputs, max_shared_memory); check_cuda_error(cudaGetLastError()); int8_t *cbs_vp_buffer =