From 0aedb1a4f427b11e7f156c867f90a44f6fed348b Mon Sep 17 00:00:00 2001
From: Beka Barbakadze
Date: Thu, 3 Nov 2022 14:10:57 +0400
Subject: [PATCH] feat(cuda): Add circuit bootstrap in the cuda backend

- Add FP-Keyswitch.
- Add entry points for cuda fp ksk in the public API.
- Add test for fp_ksk in cuda backend.
- Add fixture for bit extract

Co-authored-by: agnesLeroy
---
 include/bootstrap.h   |  19 +++++
 include/keyswitch.h   |  18 +++++
 src/bootstrap_wop.cu  | 177 ++++++++++++++++++++++++++++++++++++++++++
 src/bootstrap_wop.cuh | 132 ++++++++++++++++++++++++++++++-
 src/keyswitch.cu      |  33 ++++++++
 src/keyswitch.cuh     |  89 +++++++++++++++++++++
 6 files changed, 465 insertions(+), 3 deletions(-)

diff --git a/include/bootstrap.h b/include/bootstrap.h
index cd76c8ca5..bb9dfd66d 100644
--- a/include/bootstrap.h
+++ b/include/bootstrap.h
@@ -86,6 +86,25 @@ void cuda_extract_bits_64(
     uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
     uint32_t base_log_ksk, uint32_t level_count_ksk,
     uint32_t number_of_samples, uint32_t max_shared_memory);
+void cuda_circuit_bootstrap_32(
+    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
+    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
+    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
+    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
+    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
+    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
+    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
+    uint32_t number_of_samples, uint32_t max_shared_memory);
+
+void cuda_circuit_bootstrap_64(
+    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
+    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
+    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
+    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
+    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
+    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
+    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
+    uint32_t number_of_samples, uint32_t max_shared_memory);
 }

 #ifdef __CUDACC__
diff --git a/include/keyswitch.h b/include/keyswitch.h
index badfa837b..42526040d 100644
--- a/include/keyswitch.h
+++ b/include/keyswitch.h
@@ -14,6 +14,24 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
     void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
     void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
     uint32_t base_log, uint32_t level_count, uint32_t num_samples);
+
+void cuda_fp_keyswitch_lwe_to_glwe_32(void *v_stream, void *glwe_array_out,
+                                      void *lwe_array_in, void *fp_ksk_array,
+                                      uint32_t input_lwe_dimension,
+                                      uint32_t output_glwe_dimension,
+                                      uint32_t output_polynomial_size,
+                                      uint32_t base_log, uint32_t level_count,
+                                      uint32_t number_of_input_lwe,
+                                      uint32_t number_of_keys);
+
+void cuda_fp_keyswitch_lwe_to_glwe_64(void *v_stream, void *glwe_array_out,
+                                      void *lwe_array_in, void *fp_ksk_array,
+                                      uint32_t input_lwe_dimension,
+                                      uint32_t output_glwe_dimension,
+                                      uint32_t output_polynomial_size,
+                                      uint32_t base_log, uint32_t level_count,
+                                      uint32_t number_of_input_lwe,
+                                      uint32_t number_of_keys);
 }

 #endif // CNCRT_KS_H_
diff --git a/src/bootstrap_wop.cu b/src/bootstrap_wop.cu
index 24d9f6901..cd315647b 100644
--- a/src/bootstrap_wop.cu
+++ b/src/bootstrap_wop.cu
@@ -343,3 +343,180 @@ void cuda_blind_rotate_and_sample_extraction_64(
     break;
   }
 }
+
+void cuda_circuit_bootstrap_32(
+    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
+    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
+    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
+    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
+    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
+    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
+    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
+    uint32_t number_of_samples, uint32_t max_shared_memory) {
+  assert(("Error (GPU circuit bootstrap): glwe_dimension should be equal to 1",
+          glwe_dimension == 1));
+  assert(("Error (GPU circuit bootstrap): polynomial_size should be one of "
+          "512, 1024, 2048, 4096, 8192",
+          polynomial_size == 512 || polynomial_size == 1024 ||
+              polynomial_size == 2048 || polynomial_size == 4096 ||
+              polynomial_size == 8192));
+  // The number of samples should be lower than the number of streaming
+  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
+  // to the occupancy of 50%). The only supported value for k is 1, so
+  // k + 1 = 2 for now.
+  int number_of_sm = 0;
+  cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount, 0);
+  assert(("Error (GPU circuit bootstrap): the number of input LWEs must be "
+          "lower or equal to the "
+          "number of streaming multiprocessors on the device divided by "
+          "8 * level_bsk",
+          number_of_samples <= number_of_sm / 4. / 2. / level_bsk));
+  switch (polynomial_size) {
+  case 512:
+    host_circuit_bootstrap<uint32_t, AmortizedDegree<512>>(
+        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
+        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
+        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 1024:
+    host_circuit_bootstrap<uint32_t, AmortizedDegree<1024>>(
+        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
+        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
+        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 2048:
+    host_circuit_bootstrap<uint32_t, AmortizedDegree<2048>>(
+        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
+        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
+        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 4096:
+    host_circuit_bootstrap<uint32_t, AmortizedDegree<4096>>(
+        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
+        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
+        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 8192:
+    host_circuit_bootstrap<uint32_t, AmortizedDegree<8192>>(
+        v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array,
+        (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer,
+        (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  default:
+    break;
+  }
+}
+
+void cuda_circuit_bootstrap_64(
+    void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in,
+    void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer,
+    void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer,
+    void *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
+    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
+    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
+    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
+    uint32_t number_of_samples, uint32_t max_shared_memory) {
+  assert(("Error (GPU circuit bootstrap): glwe_dimension should be equal to 1",
+          glwe_dimension == 1));
+  assert(("Error (GPU circuit bootstrap): polynomial_size should be one of "
+          "512, 1024, 2048, 4096, 8192",
+          polynomial_size == 512 || polynomial_size == 1024 ||
+              polynomial_size == 2048 || polynomial_size == 4096 ||
+              polynomial_size == 8192));
+  // The number of samples should be lower than the number of streaming
+  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
+  // to the occupancy of 50%). The only supported value for k is 1, so
+  // k + 1 = 2 for now.
+  int number_of_sm = 0;
+  cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount, 0);
+  assert(("Error (GPU circuit bootstrap): the number of input LWEs must be "
+          "lower or equal to the "
+          "number of streaming multiprocessors on the device divided by "
+          "8 * level_bsk",
+          number_of_samples <= number_of_sm / 4. / 2. / level_bsk));
+  switch (polynomial_size) {
+  case 512:
+    host_circuit_bootstrap<uint64_t, AmortizedDegree<512>>(
+        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
+        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
+        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 1024:
+    host_circuit_bootstrap<uint64_t, AmortizedDegree<1024>>(
+        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
+        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
+        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 2048:
+    host_circuit_bootstrap<uint64_t, AmortizedDegree<2048>>(
+        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
+        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
+        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 4096:
+    host_circuit_bootstrap<uint64_t, AmortizedDegree<4096>>(
+        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
+        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
+        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  case 8192:
+    host_circuit_bootstrap<uint64_t, AmortizedDegree<8192>>(
+        v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in,
+        (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array,
+        (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector,
+        (uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer,
+        (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size,
+        glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk,
+        base_log_pksk, level_cbs, base_log_cbs, number_of_samples,
+        max_shared_memory);
+    break;
+  default:
+    break;
+  }
+}
diff --git a/src/bootstrap_wop.cuh b/src/bootstrap_wop.cuh
index d1ccb7105..3a6e25ba1 100644
--- a/src/bootstrap_wop.cuh
+++ b/src/bootstrap_wop.cuh
@@ -5,6 +5,7 @@
 #include "../include/helper_cuda.h"
 #include "bootstrap.h"
+#include "bootstrap_amortized.cuh"
 #include "bootstrap_low_latency.cuh"
 #include "complex/operations.cuh"
 #include "crypto/ggsw.cuh"
@@ -377,6 +378,31 @@ __global__ void copy_and_shift_lwe(Torus *dst_copy, Torus *dst_shift,
   }
 }

+// works for lwe with generic sizes
+// shifted_lwe_buffer is the scalar multiplication of the lwe input
+// blockIdx.y refers to the input ciphertext id
+template <typename Torus>
+__global__ void shift_lwe_cbs(Torus *dst_shift, Torus *src, Torus value,
+                              size_t lwe_size) {
+
+  size_t blockId = blockIdx.y * gridDim.x + blockIdx.x;
+  size_t threads_per_block = blockDim.x;
+  size_t opt = lwe_size / threads_per_block;
+  size_t rem = lwe_size & (threads_per_block - 1);
+
+  auto cur_dst = &dst_shift[blockId * lwe_size];
+  auto cur_src = &src[blockIdx.y * lwe_size];
+
+  size_t tid = threadIdx.x;
+  for (size_t i = 0; i < opt; i++) {
+    cur_dst[tid] = cur_src[tid] * value;
+    tid += threads_per_block;
+  }
+
+  if (threadIdx.x < rem)
+    cur_dst[tid] = cur_src[tid] * value;
+}
+
 // only works for small lwe in ks+bs case
 // function copies lwe when length is not a power of two
 template <typename Torus>
@@ -413,11 +439,11 @@ __global__ void add_to_body(Torus *lwe, size_t lwe_dimension, Torus value) {

 // Fill lut(only body) for the current bit (equivalent to trivial encryption as
 // mask is 0s)
-// The LUT is filled with -alpha in each coefficient where alpha =
-// delta*2^{bit_idx-1}
+// The LUT is filled with value
 template <typename Torus, class params>
 __global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
-  Torus *cur_poly = &lut[params::degree];
+
+  Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
   size_t tid = threadIdx.x;
 #pragma unroll
   for (int i = 0; i < params::opt; i++) {
@@ -426,6 +452,46 @@ __global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
   }
 }

+// Fill lut (equivalent to trivial encryption as mask is 0s)
+// The LUT is filled with -alpha in each coefficient where
+// alpha = 2^{log(q) - 1 - base_log * level}
+template <typename Torus, class params>
+__global__ void fill_lut_body_for_cbs(Torus *lut, uint32_t ciphertext_n_bits,
+                                      uint32_t base_log_cbs) {
+
+  Torus *cur_mask = &lut[blockIdx.x * 2 * params::degree];
+  Torus *cur_poly = &lut[blockIdx.x * 2 * params::degree + params::degree];
+  size_t tid = threadIdx.x;
+#pragma unroll
+  for (int i = 0; i < params::opt; i++) {
+    cur_mask[tid] = 0;
+    cur_poly[tid] =
+        0ll - 1ll << (ciphertext_n_bits - 1 - base_log_cbs * (blockIdx.x + 1));
+    tid += params::degree / params::opt;
+  }
+}
+
+template <typename Torus, class params>
+__global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src, Torus value) {
+
+  size_t tid = threadIdx.x;
+  size_t src_lwe_id = blockIdx.y;
+  size_t dst_lwe_id = blockIdx.x;
+
+  auto cur_src = &lwe_src[src_lwe_id * (params::degree + 1)];
+  auto cur_dst = &lwe_dst[dst_lwe_id * (params::degree + 1)];
+
+#pragma unroll
+  for (int i = 0; i < params::opt; i++) {
+    cur_dst[tid] = cur_src[tid];
+    tid += params::degree / params::opt;
+  }
+
+  if (threadIdx.x == 0) {
+    cur_dst[params::degree] = cur_src[params::degree] + value;
+  }
+}
+
 // Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0
 // if the extracted bit was 0 and 1 in the other case
 //
@@ -683,4 +749,64 @@ void host_blind_rotate_and_sample_extraction(
   cuda_drop_async(d_mem, *stream, gpu_index);
 }
+template <typename Torus, class params>
+__host__ void host_circuit_bootstrap(
+    void *v_stream, uint32_t gpu_index, Torus *ggsw_out, Torus *lwe_array_in,
+    double2 *fourier_bsk, Torus *fp_ksk_array,
+    Torus *lwe_array_in_shifted_buffer, Torus *lut_vector,
+    uint32_t *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer,
+    Torus *lwe_array_in_fp_ks_buffer, uint32_t delta_log,
+    uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension,
+    uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk,
+    uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs,
+    uint32_t number_of_samples, uint32_t max_shared_memory) {
+  auto stream = static_cast<cudaStream_t *>(v_stream);
+
+  uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
+  uint32_t lwe_size = lwe_dimension + 1;
+  int pbs_count = number_of_samples * level_cbs;
+
+  dim3 blocks(level_cbs, number_of_samples, 1);
+  int threads = 256;
+
+  // Shift message LSB on padding bit, at this point we expect to have messages
+  // with only 1 bit of information
+  shift_lwe_cbs<<<blocks, threads, 0, *stream>>>(
+      lwe_array_in_shifted_buffer, lwe_array_in,
+      1LL << (ciphertext_n_bits - delta_log - 1), lwe_size);
+
+  // Add q/4 to center the error while computing a negacyclic LUT
+  add_to_body
+      <<<pbs_count, 1, 0, *stream>>>(lwe_array_in_shifted_buffer, lwe_dimension,
+                                     1ll << (ciphertext_n_bits - 2));
+  // Fill lut (equivalent to trivial encryption as mask is 0s)
+  // The LUT is filled with -alpha in each coefficient where
+  // alpha = 2^{log(q) - 1 - base_log * level}
+  fill_lut_body_for_cbs<Torus, params>
+      <<<level_cbs, params::degree / params::opt, 0, *stream>>>(
+          lut_vector, ciphertext_n_bits, base_log_cbs);
+
+  // Applying a negacyclic LUT on a ciphertext with one bit of message in the
+  // MSB and no bit of padding
+  host_bootstrap_amortized<Torus, params>(
+      v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_vector,
+      lut_vector_indexes, lwe_array_in_shifted_buffer, fourier_bsk,
+      lwe_dimension, polynomial_size, base_log_bsk, level_bsk, pbs_count,
+      level_cbs, 0, max_shared_memory);
+
+  dim3 copy_grid(pbs_count * (glwe_dimension + 1), pbs_count, 1);
+  dim3 copy_block(params::degree / params::opt, 1, 1);
+
+  // Add alpha = 2^{log(q) - 1 - base_log_cbs * level_cbs} to the body and
+  // copy pbs result (glwe_dimension + 1) times to be an input of fp-ks
+  copy_add_lwe_cbs<Torus, params><<<copy_grid, copy_block, 0, *stream>>>(
+      lwe_array_in_fp_ks_buffer, lwe_array_out_pbs_buffer,
+      1ll << (ciphertext_n_bits - 1 - base_log_cbs * level_cbs));
+
+  cuda_fp_keyswitch_lwe_to_glwe(
+      v_stream, ggsw_out, lwe_array_in_fp_ks_buffer, fp_ksk_array,
+      polynomial_size, glwe_dimension, polynomial_size, base_log_pksk,
+      level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1);
+}
+
 #endif // WO_PBS_H
diff --git a/src/keyswitch.cu b/src/keyswitch.cu
index a4ff04ab4..c9b9377ac 100644
--- a/src/keyswitch.cu
+++ b/src/keyswitch.cu
@@ -43,3 +43,36 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
       static_cast<uint64_t *>(lwe_array_in), static_cast<uint64_t *>(ksk),
       lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
 }
+
+void cuda_fp_keyswitch_lwe_to_glwe_32(void *v_stream, void *glwe_array_out,
+                                      void *lwe_array_in, void *fp_ksk_array,
+                                      uint32_t input_lwe_dimension,
+                                      uint32_t output_glwe_dimension,
+                                      uint32_t output_polynomial_size,
+                                      uint32_t base_log, uint32_t level_count,
+                                      uint32_t number_of_input_lwe,
+                                      uint32_t number_of_keys) {
+
+  cuda_fp_keyswitch_lwe_to_glwe(
+      v_stream, static_cast<uint32_t *>(glwe_array_out),
+      static_cast<uint32_t *>(lwe_array_in),
+      static_cast<uint32_t *>(fp_ksk_array), input_lwe_dimension,
+      output_glwe_dimension, output_polynomial_size, base_log, level_count,
+      number_of_input_lwe, number_of_keys);
+}
+void cuda_fp_keyswitch_lwe_to_glwe_64(void *v_stream, void *glwe_array_out,
+                                      void *lwe_array_in, void *fp_ksk_array,
+                                      uint32_t input_lwe_dimension,
+                                      uint32_t output_glwe_dimension,
+                                      uint32_t output_polynomial_size,
+                                      uint32_t base_log, uint32_t level_count,
+                                      uint32_t number_of_input_lwe,
+                                      uint32_t number_of_keys) {
+
+  cuda_fp_keyswitch_lwe_to_glwe(
+      v_stream, static_cast<uint64_t *>(glwe_array_out),
+      static_cast<uint64_t *>(lwe_array_in),
+      static_cast<uint64_t *>(fp_ksk_array), input_lwe_dimension,
+      output_glwe_dimension, output_polynomial_size, base_log, level_count,
+      number_of_input_lwe, number_of_keys);
+}
diff --git a/src/keyswitch.cuh b/src/keyswitch.cuh
index dca98359c..1fc236432 100644
--- a/src/keyswitch.cuh
+++ b/src/keyswitch.cuh
@@ -17,6 +17,75 @@ __device__ Torus *get_ith_block(Torus *ksk, int i, int level,
   return ptr;
 }

+// blockIdx.y represents a single lwe ciphertext
+// blockIdx.x represents a chunk of the lwe ciphertext,
+// chunk_count = glwe_size * polynomial_size / threads.
+// each thread is responsible for only lwe_size multiplications
+template <typename Torus>
+__global__ void
+fp_keyswitch(Torus *glwe_array_out, Torus *lwe_array_in, Torus *fp_ksk_array,
+             uint32_t lwe_dimension_in, uint32_t glwe_dimension,
+             uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
+             uint32_t number_of_input_lwe, uint32_t number_of_keys) {
+  size_t tid = threadIdx.x;
+
+  size_t glwe_size = (glwe_dimension + 1);
+  size_t lwe_size = (lwe_dimension_in + 1);
+
+  // number of coefficients in a single fp-ksk
+  size_t ksk_size = lwe_size * level_count * glwe_size * polynomial_size;
+
+  // number of coefficients inside fp-ksk block for each lwe_input coefficient
+  size_t ksk_block_size = glwe_size * polynomial_size * level_count;
+
+  size_t ciphertext_id = blockIdx.y;
+  // number of coefficients processed inside single block
+  size_t coef_per_block = blockDim.x;
+  size_t chunk_id = blockIdx.x;
+  size_t ksk_id = ciphertext_id % number_of_keys;
+
+  extern __shared__ char sharedmem[];
+
+  // result accumulator, shared memory is used because of frequent access
+  Torus *local_glwe_chunk = (Torus *)sharedmem;
+
+  // current input lwe ciphertext
+  auto cur_input_lwe = &lwe_array_in[ciphertext_id * lwe_size];
+  // current output glwe ciphertext
+  auto cur_output_glwe =
+      &glwe_array_out[ciphertext_id * glwe_size * polynomial_size];
+  // current out glwe chunk, will be processed inside single block
+  auto cur_glwe_chunk = &cur_output_glwe[chunk_id * coef_per_block];
+
+  // fp key used for current ciphertext
+  auto cur_ksk = &fp_ksk_array[ksk_id * ksk_size];
+
+  // set shared mem accumulator to 0
+  local_glwe_chunk[tid] = 0;
+
+  // iterate through each coefficient of input lwe
+  for (size_t i = 0; i <= lwe_dimension_in; i++) {
+    Torus a_i =
+        round_to_closest_multiple(cur_input_lwe[i], base_log, level_count);
+
+    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
+    Torus mod_b_mask = (1ll << base_log) - 1ll;
+
+    // block of key for current lwe coefficient (cur_input_lwe[i])
+    auto ksk_block = &cur_ksk[i * ksk_block_size];
+
+    // iterate through levels, calculating decomposition in reverse order
+    for (size_t j = 0; j < level_count; j++) {
+      auto ksk_glwe =
+          &ksk_block[(level_count - j - 1) * glwe_size * polynomial_size];
+      auto ksk_glwe_chunk = &ksk_glwe[chunk_id * coef_per_block];
+      Torus decomposed = decompose_one(state, mod_b_mask, base_log);
+      local_glwe_chunk[tid] -= decomposed * ksk_glwe_chunk[tid];
+    }
+  }
+  cur_glwe_chunk[tid] = local_glwe_chunk[tid];
+}
+
 /*
  * keyswitch kernel
  * Each thread handles a piece of the following equation:
@@ -138,4 +207,24 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(

   cudaStreamSynchronize(*stream);
 }
+template <typename Torus>
+__host__ void cuda_fp_keyswitch_lwe_to_glwe(
+    void *v_stream, Torus *glwe_array_out, Torus *lwe_array_in,
+    Torus *fp_ksk_array, uint32_t lwe_dimension_in, uint32_t glwe_dimension,
+    uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
+    uint32_t number_of_input_lwe, uint32_t number_of_keys) {
+  int threads = 256;
+  int glwe_accumulator_size = (glwe_dimension + 1) * polynomial_size;
+  dim3 blocks(glwe_accumulator_size / threads, number_of_input_lwe, 1);
+
+  int shared_mem = sizeof(Torus) * threads;
+  auto stream = static_cast<cudaStream_t *>(v_stream);
+  fp_keyswitch<<<blocks, threads, shared_mem, *stream>>>(
+      glwe_array_out, lwe_array_in, fp_ksk_array, lwe_dimension_in,
+      glwe_dimension, polynomial_size, base_log, level_count,
+      number_of_input_lwe, number_of_keys);
+
+  cudaStreamSynchronize(*stream);
+}
+
 #endif
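
For reference, below is a minimal host-side sketch of how the new cuda_fp_keyswitch_lwe_to_glwe_64 entry point can be driven once the patch is applied. It is only an illustrative sketch, not part of the patch: the parameter values are toy choices (picked so that (glwe_dimension + 1) * polynomial_size stays divisible by the 256-thread blocks used by the host wrapper), the key and ciphertext buffers are zero-filled placeholders standing in for material produced by the CPU key generation, and the include path is an assumption that depends on how the backend headers are consumed. Buffer sizes follow the indexing used in the fp_keyswitch kernel (one fp-ksk holds lwe_size * level_count * glwe_size * polynomial_size coefficients).

#include <cstdint>
#include <vector>

#include <cuda_runtime.h>

#include "keyswitch.h" // assumed include path; adjust to the build setup

int main() {
  // Toy parameters, chosen only to exercise the call.
  uint32_t input_lwe_dimension = 600;
  uint32_t glwe_dimension = 1;
  uint32_t polynomial_size = 1024;
  uint32_t base_log = 4;
  uint32_t level_count = 3;
  uint32_t number_of_input_lwe = 2;
  uint32_t number_of_keys = 1;

  size_t lwe_size = input_lwe_dimension + 1;
  size_t glwe_size = (glwe_dimension + 1) * polynomial_size;
  // Coefficients per functional packing keyswitch key, matching ksk_size in
  // the fp_keyswitch kernel.
  size_t ksk_size = lwe_size * level_count * glwe_size;

  // Placeholder host data; real ciphertexts and keys come from the CPU side.
  std::vector<uint64_t> h_lwe(number_of_input_lwe * lwe_size, 0);
  std::vector<uint64_t> h_ksk(number_of_keys * ksk_size, 0);
  std::vector<uint64_t> h_glwe(number_of_input_lwe * glwe_size, 0);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  uint64_t *d_lwe = nullptr, *d_ksk = nullptr, *d_glwe = nullptr;
  cudaMalloc(&d_lwe, h_lwe.size() * sizeof(uint64_t));
  cudaMalloc(&d_ksk, h_ksk.size() * sizeof(uint64_t));
  cudaMalloc(&d_glwe, h_glwe.size() * sizeof(uint64_t));
  cudaMemcpy(d_lwe, h_lwe.data(), h_lwe.size() * sizeof(uint64_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_ksk, h_ksk.data(), h_ksk.size() * sizeof(uint64_t),
             cudaMemcpyHostToDevice);

  // v_stream is a void * pointing at the cudaStream_t, matching the
  // static_cast<cudaStream_t *>(v_stream) done inside the backend.
  cuda_fp_keyswitch_lwe_to_glwe_64(&stream, d_glwe, d_lwe, d_ksk,
                                   input_lwe_dimension, glwe_dimension,
                                   polynomial_size, base_log, level_count,
                                   number_of_input_lwe, number_of_keys);

  // The host wrapper synchronizes the stream, so the result can be copied
  // back right away.
  cudaMemcpy(h_glwe.data(), d_glwe, h_glwe.size() * sizeof(uint64_t),
             cudaMemcpyDeviceToHost);

  cudaFree(d_lwe);
  cudaFree(d_ksk);
  cudaFree(d_glwe);
  cudaStreamDestroy(stream);
  return 0;
}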