From b0a362af6d8b31e68d06e9e38e40762fbc31df34 Mon Sep 17 00:00:00 2001
From: Pedro Alves
Date: Wed, 22 Feb 2023 16:55:14 -0300
Subject: [PATCH] refactor(cuda): Implement support for k>1 on cbs+vp.

---
 src/circuit_bootstrap.cuh | 52 ++++++++++++++++++++++++---------------
 src/vertical_packing.cuh  | 40 +++++++++++++++---------------
 2 files changed, 52 insertions(+), 40 deletions(-)

diff --git a/src/circuit_bootstrap.cuh b/src/circuit_bootstrap.cuh
index 5152df058..84599747e 100644
--- a/src/circuit_bootstrap.cuh
+++ b/src/circuit_bootstrap.cuh
@@ -45,15 +45,15 @@ __global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
  */
 template
 __global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
-                                      uint32_t base_log_cbs) {
+                                      uint32_t base_log_cbs,
+                                      uint32_t glwe_dimension) {
 
-  Torus *cur_mask = &lut[blockIdx.x * 2 * params::degree];
-  Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
+  Torus *cur_body = &lut[(blockIdx.x * (glwe_dimension + 1) + glwe_dimension) *
+                         params::degree];
   size_t tid = threadIdx.x;
 #pragma unroll
   for (int i = 0; i < params::opt; i++) {
-    cur_mask[tid] = 0;
-    cur_poly[tid] =
+    cur_body[tid] =
         0ll -
         (1ll << (ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1)));
     tid += params::degree / params::opt;
@@ -76,22 +76,27 @@ __global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
 template
 __global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src,
                                  uint32_t ciphertext_n_bits,
-                                 uint32_t base_log_cbs, uint32_t level_cbs) {
+                                 uint32_t base_log_cbs, uint32_t level_cbs,
+                                 uint32_t glwe_dimension) {
   size_t tid = threadIdx.x;
+  size_t src_lwe_id = blockIdx.x / (glwe_dimension + 1);
   size_t dst_lwe_id = blockIdx.x;
-  size_t src_lwe_id = dst_lwe_id / 2;
   size_t cur_cbs_level = src_lwe_id % level_cbs + 1;
 
-  auto cur_src = &lwe_src[src_lwe_id * (params::degree + 1)];
-  auto cur_dst = &lwe_dst[dst_lwe_id * (params::degree + 1)];
+  auto cur_src = &lwe_src[src_lwe_id * (glwe_dimension * params::degree + 1)];
+  auto cur_dst = &lwe_dst[dst_lwe_id * (glwe_dimension * params::degree + 1)];
+
+  auto cur_src_slice = cur_src + blockIdx.y * params::degree;
+  auto cur_dst_slice = cur_dst + blockIdx.y * params::degree;
 
 #pragma unroll
   for (int i = 0; i < params::opt; i++) {
-    cur_dst[tid] = cur_src[tid];
+    cur_dst_slice[tid] = cur_src_slice[tid];
     tid += params::degree / params::opt;
   }
 
   Torus val = 1ll << (ciphertext_n_bits - 1 - base_log_cbs * cur_cbs_level);
-  if (threadIdx.x == 0) {
-    cur_dst[params::degree] = cur_src[params::degree] + val;
+  if (threadIdx.x == 0 && blockIdx.y == 0) {
+    cur_dst[glwe_dimension * params::degree] =
+        cur_src[glwe_dimension * params::degree] + val;
   }
 }
 
@@ -102,9 +107,10 @@ get_buffer_size_cbs(uint32_t glwe_dimension, uint32_t lwe_dimension,
                     uint32_t number_of_inputs) {
 
   return number_of_inputs * level_count_cbs * (glwe_dimension + 1) *
-             (polynomial_size + 1) *
+             (glwe_dimension * polynomial_size + 1) *
              sizeof(Torus) + // lwe_array_in_fp_ks_buffer
-         number_of_inputs * level_count_cbs * (polynomial_size + 1) *
+         number_of_inputs * level_count_cbs *
+             (glwe_dimension * polynomial_size + 1) *
              sizeof(Torus) + // lwe_array_out_pbs_buffer
          number_of_inputs * level_count_cbs * (lwe_dimension + 1) *
              sizeof(Torus) + // lwe_array_in_shifted_buffer
@@ -174,7 +180,8 @@ __host__ void host_circuit_bootstrap(
           sizeof(Torus));
   Torus *lwe_array_in_shifted_buffer =
       lwe_array_out_pbs_buffer +
-      (ptrdiff_t)(number_of_inputs * level_cbs * (polynomial_size + 1));
+      (ptrdiff_t)(number_of_inputs * level_cbs *
+                  (glwe_dimension * polynomial_size + 1));
   Torus *lut_vector =
       lwe_array_in_shifted_buffer +
       (ptrdiff_t)(number_of_inputs * level_cbs * (lwe_dimension + 1));
@@ -195,9 +202,13 @@ __host__ void host_circuit_bootstrap(
   // Fill lut (equivalent to trivial encryption as mask is 0s)
   // The LUT is filled with -alpha in each coefficient where
   // alpha = 2^{log(q) - 1 - base_log * level}
+  check_cuda_error(cudaMemsetAsync(lut_vector, 0,
+                                   level_cbs * (glwe_dimension + 1) *
+                                       polynomial_size * sizeof(Torus),
+                                   *stream));
   fill_lut_body_for_cbs <<>>(
-      lut_vector, ciphertext_n_bits, base_log_cbs);
+      lut_vector, ciphertext_n_bits, base_log_cbs, glwe_dimension);
 
   // Applying a negacyclic LUT on a ciphertext with one bit of message in the
   // MSB and no bit of padding
@@ -207,18 +218,19 @@ __host__ void host_circuit_bootstrap(
       glwe_dimension, lwe_dimension, polynomial_size, base_log_bsk,
       level_bsk, pbs_count, level_cbs, 0, max_shared_memory);
 
-  dim3 copy_grid(pbs_count * (glwe_dimension + 1), 1, 1);
+  dim3 copy_grid(pbs_count * (glwe_dimension + 1), glwe_dimension, 1);
   dim3 copy_block(params::degree / params::opt, 1, 1);
   // Add q/4 to center the error while computing a negacyclic LUT
   // copy pbs result (glwe_dimension + 1) times to be an input of fp-ks
   copy_add_lwe_cbs<<>>(
       lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer, ciphertext_n_bits,
-      base_log_cbs, level_cbs);
+      base_log_cbs, level_cbs, glwe_dimension);
 
   cuda_fp_keyswitch_lwe_to_glwe(
       v_stream, gpu_index, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
-      polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
-      level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
+      glwe_dimension * polynomial_size, glwe_dimension, polynomial_size,
+      base_log_pksk, level_pksk, pbs_count * (glwe_dimension + 1),
+      glwe_dimension + 1);
 }
 
 #endif // CBS_CUH
diff --git a/src/vertical_packing.cuh b/src/vertical_packing.cuh
index a1ce2fc2c..9f55d1b47 100644
--- a/src/vertical_packing.cuh
+++ b/src/vertical_packing.cuh
@@ -420,14 +420,16 @@ __global__ void device_blind_rotation_and_sample_extraction(
     selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
 
   Torus *accumulator_c0 = (Torus *)selected_memory;
-  Torus *accumulator_c1 = (Torus *)accumulator_c0 + 2 * polynomial_size;
+  Torus *accumulator_c1 =
+      (Torus *)accumulator_c0 + (glwe_dim + 1) * polynomial_size;
+  int8_t *cmux_memory =
+      (int8_t *)(accumulator_c1 + (glwe_dim + 1) * polynomial_size);
 
   // Input LUT
   auto mi = &glwe_in[blockIdx.x * (glwe_dim + 1) * polynomial_size];
   int tid = threadIdx.x;
-  for (int i = 0; i < params::opt; i++) {
+  for (int i = 0; i < (glwe_dim + 1) * params::opt; i++) {
     accumulator_c0[tid] = mi[tid];
-    accumulator_c0[tid + params::degree] = mi[tid + params::degree];
     tid += params::degree / params::opt;
   }
 
@@ -436,45 +438,43 @@ __global__ void device_blind_rotation_and_sample_extraction(
     synchronize_threads_in_block();
 
     // Compute x^ai * ACC
-    // Body
+    // Mask and Body
     divide_by_monomial_negacyclic_inplace(
-        accumulator_c1, accumulator_c0, (1 << monomial_degree), false, 1);
-    // Mask
-    divide_by_monomial_negacyclic_inplace(
-        accumulator_c1 + polynomial_size, accumulator_c0 + polynomial_size,
-        (1 << monomial_degree), false, 1);
+        accumulator_c1, accumulator_c0, (1 << monomial_degree), false,
+        (glwe_dim + 1));
 
     monomial_degree += 1;
 
     // ACC = CMUX ( Ci, x^ai * ACC, ACC )
     synchronize_threads_in_block();
-    cmux(
-        accumulator_c0, accumulator_c0, ggsw_in,
-        (int8_t *)(accumulator_c0 + 4 * polynomial_size), 0, 0, 1, glwe_dim,
-        polynomial_size, base_log, level_count, i);
+    cmux(accumulator_c0, accumulator_c0, ggsw_in,
+         cmux_memory, 0, 0, 1, glwe_dim, polynomial_size,
+         base_log, level_count, i);
   }
   synchronize_threads_in_block();
 
   // Write the output
-  auto block_lwe_out = &lwe_out[blockIdx.x * (polynomial_size + 1)];
+  auto block_lwe_out = &lwe_out[blockIdx.x * (glwe_dim * polynomial_size + 1)];
 
   // The blind rotation for this block is over
   // Now we can perform the sample extraction: for the body it's just
   // the resulting constant coefficient of the accumulator
   // For the mask it's more complicated
-  sample_extract_mask(block_lwe_out, accumulator_c0, 1);
-  sample_extract_body(block_lwe_out, accumulator_c0, 1);
+  sample_extract_mask(block_lwe_out, accumulator_c0, glwe_dim);
+  sample_extract_body(block_lwe_out, accumulator_c0, glwe_dim);
 }
 
 template
 __host__ __device__ int
 get_memory_needed_per_block_blind_rotation_sample_extraction(
     uint32_t glwe_dimension, uint32_t polynomial_size) {
-  return sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c0
-  sizeof(Torus) * polynomial_size * (glwe_dimension+1) + // accumulator_c1
-  + get_memory_needed_per_block_cmux_tree(glwe_dimension, polynomial_size);
+  return sizeof(Torus) * polynomial_size *
+             (glwe_dimension + 1) + // accumulator_c0
+         sizeof(Torus) * polynomial_size *
+             (glwe_dimension + 1) + // accumulator_c1
+         +get_memory_needed_per_block_cmux_tree(glwe_dimension,
+                                                polynomial_size);
 }
 
 template
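
Note (not part of the patch): the index arithmetic the diff generalizes from glwe_dimension = 1 to arbitrary k can be checked in isolation. The minimal host-side C++ sketch below, with illustrative helper names, reproduces the three formulas the change touches: the offset of a LUT body polynomial inside a trivial GLWE holding glwe_dimension mask polynomials, the -2^(log2(q) - 1 - base_log * level) constant written into that body, and the glwe_dimension * polynomial_size + 1 size of a sample-extracted LWE ciphertext. With glwe_dimension = 1 they reduce to the constants the old code hard-coded.

// Sketch only; helper names are illustrative, not part of the patched sources.
#include <cassert>
#include <cstdint>
#include <cstdio>

using Torus = uint64_t;

// Offset (in Torus elements) of the body polynomial of the trivial GLWE LUT for
// decomposition level `level_idx` (0-based): each LUT stores glwe_dimension mask
// polynomials followed by one body polynomial of polynomial_size coefficients.
uint64_t lut_body_offset(uint64_t level_idx, uint64_t glwe_dimension,
                         uint64_t polynomial_size) {
  return (level_idx * (glwe_dimension + 1) + glwe_dimension) * polynomial_size;
}

// Constant written into every body coefficient for level l (1-based):
// -2^(ciphertext_n_bits - 1 - base_log_cbs * l), as in fill_lut_body_for_cbs.
Torus lut_body_value(uint32_t ciphertext_n_bits, uint32_t base_log_cbs,
                     uint32_t level) {
  return (Torus)0 -
         ((Torus)1 << (ciphertext_n_bits - 1 - base_log_cbs * level));
}

// Size (in Torus elements) of one LWE ciphertext sample-extracted from a GLWE
// ciphertext: glwe_dimension * polynomial_size mask coefficients plus one body.
uint64_t extracted_lwe_size(uint64_t glwe_dimension, uint64_t polynomial_size) {
  return glwe_dimension * polynomial_size + 1;
}

int main() {
  const uint64_t N = 1024;                // polynomial_size (example value)
  const uint32_t base_log = 6, bits = 64; // base_log_cbs, log2(q) (example values)

  // With glwe_dimension = 1 the formulas reduce to the pre-patch constants.
  assert(lut_body_offset(0, 1, N) == N);     // old: blockIdx.x * 2 * N + N
  assert(extracted_lwe_size(1, N) == N + 1); // old: polynomial_size + 1

  // With glwe_dimension = 2 the level-1 body sits after two mask polynomials.
  assert(lut_body_offset(0, 2, N) == 2 * N);
  assert(extracted_lwe_size(2, N) == 2 * N + 1);

  printf("level-1 LUT body value: 0x%016llx\n",
         (unsigned long long)lut_body_value(bits, base_log, 1));
  return 0;
}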