mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
chore(cuda): rename some variables to match concrete-core notations
- rename l_gadget and stop calling low lat PBS with N too large - rename trlwe and trgsw - rename lwe_mask_size into lwe_dimension - rename lwe_in into lwe_array_in - rename lwe_out into lwe_array_out - rename decomp_level into level - rename lwe_dimension_before/after into lwe_dimension_in/out
This commit is contained in:
@@ -10,91 +10,90 @@ void cuda_initialize_twiddles(uint32_t polynomial_size, uint32_t gpu_index);
|
||||
void cuda_convert_lwe_bootstrap_key_32(void *dest, void *src, void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
uint32_t input_lwe_dim,
|
||||
uint32_t glwe_dim, uint32_t l_gadget,
|
||||
uint32_t glwe_dim, uint32_t level_count,
|
||||
uint32_t polynomial_size);
|
||||
|
||||
void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
uint32_t input_lwe_dim,
|
||||
uint32_t glwe_dim, uint32_t l_gadget,
|
||||
uint32_t glwe_dim, uint32_t level_count,
|
||||
uint32_t polynomial_size);
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_cmux_tree_32(void *v_stream, void *glwe_out, void *ggsw_in,
|
||||
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t r,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
void cuda_cmux_tree_64(void *v_stream, void *glwe_out, void *ggsw_in,
|
||||
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t r,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
void cuda_extract_bits_32(void *v_stream, void *list_lwe_out, void *lwe_in,
|
||||
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk,
|
||||
void *fourier_bsk, uint32_t number_of_bits,
|
||||
uint32_t delta_log, uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
|
||||
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples);
|
||||
void cuda_extract_bits_32(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples);
|
||||
|
||||
void cuda_extract_bits_64(void *v_stream, void *list_lwe_out, void *lwe_in,
|
||||
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk,
|
||||
void *fourier_bsk, uint32_t number_of_bits,
|
||||
uint32_t delta_log, uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
|
||||
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples);
|
||||
void cuda_extract_bits_64(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples);
|
||||
};
|
||||
|
||||
#ifdef __CUDACC__
|
||||
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
|
||||
int glwe_dimension, uint32_t l_gadget);
|
||||
int glwe_dimension,
|
||||
uint32_t level_count);
|
||||
|
||||
template <typename T>
|
||||
__device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension, uint32_t l_gadget);
|
||||
int glwe_dimension, uint32_t level_count);
|
||||
|
||||
template <typename T>
|
||||
__device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension, uint32_t l_gadget);
|
||||
int glwe_dimension, uint32_t level_count);
|
||||
#endif
|
||||
|
||||
#endif // CUDA_BOOTSTRAP_H
|
||||
|
||||
@@ -6,14 +6,14 @@
|
||||
extern "C" {
|
||||
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
|
||||
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples);
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples);
|
||||
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
|
||||
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples);
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples);
|
||||
}
|
||||
|
||||
#endif // CNCRT_KS_H_
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
/* Perform bootstrapping on a batch of input LWE ciphertexts
|
||||
*
|
||||
* - lwe_out: output batch of num_samples bootstrapped ciphertexts c =
|
||||
* - lwe_array_out: output batch of num_samples bootstrapped ciphertexts c =
|
||||
* (a0,..an-1,b) where n is the LWE dimension
|
||||
* - lut_vector: should hold as many test vectors of size polynomial_size
|
||||
* as there are input ciphertexts, but actually holds
|
||||
@@ -10,7 +10,7 @@
|
||||
* - lut_vector_indexes: stores the index corresponding to
|
||||
* which test vector to use for each sample in
|
||||
* lut_vector
|
||||
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* mask values + 1 body value
|
||||
* - bootstrapping_key: RGSW encryption of the LWE secret key sk1
|
||||
* under secret key sk2
|
||||
@@ -30,7 +30,7 @@
|
||||
* - polynomial_size: size of the test polynomial (test vector) and size of the
|
||||
* GLWE polynomial (~1024)
|
||||
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
||||
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
||||
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
||||
* - num_samples: number of encrypted input messages
|
||||
* - num_lut_vectors: parameter to set the actual number of test vectors to be
|
||||
* used
|
||||
@@ -44,7 +44,7 @@
|
||||
* to handle one or more polynomial coefficients at each stage:
|
||||
* - perform the blind rotation
|
||||
* - round the result
|
||||
* - decompose into l_gadget levels, then for each level:
|
||||
* - decompose into level_count levels, then for each level:
|
||||
* - switch to the FFT domain
|
||||
* - multiply with the bootstrapping key
|
||||
* - come back to the coefficients representation
|
||||
@@ -58,11 +58,11 @@
|
||||
*/
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
assert(
|
||||
("Error (GPU amortized PBS): base log should be <= 16", base_log <= 16));
|
||||
@@ -77,38 +77,38 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_amortized<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_amortized<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_amortized<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_bootstrap_amortized<uint32_t, Degree<4096>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_bootstrap_amortized<uint32_t, Degree<8192>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -116,11 +116,11 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
}
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
assert(
|
||||
("Error (GPU amortized PBS): base log should be <= 16", base_log <= 16));
|
||||
@@ -135,38 +135,38 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_amortized<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_amortized<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_amortized<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_bootstrap_amortized<uint64_t, Degree<4096>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_bootstrap_amortized<uint64_t, Degree<8192>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -29,35 +29,36 @@ template <typename Torus, class params, sharedMemDegree SMD>
|
||||
* Kernel launched by host_bootstrap_amortized
|
||||
*
|
||||
* Uses shared memory to increase performance
|
||||
* - lwe_out: output batch of num_samples bootstrapped ciphertexts c =
|
||||
* - lwe_array_out: output batch of num_samples bootstrapped ciphertexts c =
|
||||
* (a0,..an-1,b) where n is the LWE dimension
|
||||
* - lut_vector: should hold as many test vectors of size polynomial_size
|
||||
* as there are input ciphertexts, but actually holds
|
||||
* num_lut_vectors vectors to reduce memory usage
|
||||
* - lut_vector_indexes: stores the index corresponding to which test vector
|
||||
* to use for each sample in lut_vector
|
||||
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n mask
|
||||
* values + 1 body value
|
||||
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* mask values + 1 body value
|
||||
* - bootstrapping_key: RGSW encryption of the LWE secret key sk1 under secret
|
||||
* key sk2
|
||||
* - device_mem: pointer to the device's global memory in case we use it (SMD
|
||||
* == NOSM or PARTIALSM)
|
||||
* - lwe_mask_size: size of the Torus vector used to encrypt the input
|
||||
* - lwe_dimension: size of the Torus vector used to encrypt the input
|
||||
* LWE ciphertexts - referred to as n above (~ 600)
|
||||
* - polynomial_size: size of the test polynomial (test vector) and size of the
|
||||
* GLWE polynomial (~1024)
|
||||
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
||||
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
||||
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
||||
* - gpu_num: index of the current GPU (useful for multi-GPU computations)
|
||||
* - lwe_idx: equal to the number of samples per gpu x gpu_num
|
||||
* - device_memory_size_per_sample: amount of global memory to allocate if SMD
|
||||
* is not FULLSM
|
||||
*/
|
||||
__global__ void device_bootstrap_amortized(
|
||||
Torus *lwe_out, Torus *lut_vector, uint32_t *lut_vector_indexes,
|
||||
Torus *lwe_in, double2 *bootstrapping_key, char *device_mem,
|
||||
uint32_t lwe_mask_size, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t lwe_idx, size_t device_memory_size_per_sample) {
|
||||
Torus *lwe_array_out, Torus *lut_vector, uint32_t *lut_vector_indexes,
|
||||
Torus *lwe_array_in, double2 *bootstrapping_key, char *device_mem,
|
||||
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t lwe_idx,
|
||||
size_t device_memory_size_per_sample) {
|
||||
// We use shared memory for the polynomials that are used often during the
|
||||
// bootstrap, since shared memory is kept in L1 cache and accessing it is
|
||||
// much faster than global memory
|
||||
@@ -69,7 +70,7 @@ __global__ void device_bootstrap_amortized(
|
||||
else
|
||||
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
|
||||
|
||||
// For GPU bootstrapping the RLWE dimension is hard-set to 1: there is only
|
||||
// For GPU bootstrapping the GLWE dimension is hard-set to 1: there is only
|
||||
// one mask polynomial and 1 body to handle Also, since the decomposed
|
||||
// polynomials take coefficients between -B/2 and B/2 they can be represented
|
||||
// with only 16 bits, assuming the base log does not exceed 2^16
|
||||
@@ -93,16 +94,16 @@ __global__ void device_bootstrap_amortized(
|
||||
accumulator_fft =
|
||||
(double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
|
||||
|
||||
auto block_lwe_in = &lwe_in[blockIdx.x * (lwe_mask_size + 1)];
|
||||
auto block_lwe_array_in = &lwe_array_in[blockIdx.x * (lwe_dimension + 1)];
|
||||
Torus *block_lut_vector =
|
||||
&lut_vector[lut_vector_indexes[lwe_idx + blockIdx.x] * params::degree *
|
||||
2];
|
||||
|
||||
GadgetMatrix<Torus, params> gadget(base_log, l_gadget);
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count);
|
||||
|
||||
// Put "b", the body, in [0, 2N[
|
||||
Torus b_hat = rescale_torus_element(
|
||||
block_lwe_in[lwe_mask_size],
|
||||
block_lwe_array_in[lwe_dimension],
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
|
||||
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
|
||||
@@ -115,14 +116,14 @@ __global__ void device_bootstrap_amortized(
|
||||
|
||||
// Loop over all the mask elements of the sample to accumulate
|
||||
// (X^a_i-1) multiplication, decomposition of the resulting polynomial
|
||||
// into l_gadget polynomials, and performing polynomial multiplication
|
||||
// into level_count polynomials, and performing polynomial multiplication
|
||||
// via an FFT with the RGSW encrypted secret key
|
||||
for (int iteration = 0; iteration < lwe_mask_size; iteration++) {
|
||||
for (int iteration = 0; iteration < lwe_dimension; iteration++) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Put "a" in [0, 2N[ instead of Zq
|
||||
Torus a_hat = rescale_torus_element(
|
||||
block_lwe_in[iteration],
|
||||
block_lwe_array_in[iteration],
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
|
||||
// Perform ACC * (X^ä - 1)
|
||||
@@ -140,11 +141,11 @@ __global__ void device_bootstrap_amortized(
|
||||
// bootstrapped ciphertext
|
||||
round_to_closest_multiple_inplace<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
accumulator_mask_rotated, base_log, l_gadget);
|
||||
accumulator_mask_rotated, base_log, level_count);
|
||||
|
||||
round_to_closest_multiple_inplace<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
accumulator_body_rotated, base_log, l_gadget);
|
||||
accumulator_body_rotated, base_log, level_count);
|
||||
// Initialize the polynomial multiplication via FFT arrays
|
||||
// The polynomial multiplications happens at the block level
|
||||
// and each thread handles two or more coefficients
|
||||
@@ -160,13 +161,13 @@ __global__ void device_bootstrap_amortized(
|
||||
// Now that the rotation is done, decompose the resulting polynomial
|
||||
// coefficients so as to multiply each decomposed level with the
|
||||
// corresponding part of the bootstrapping key
|
||||
for (int decomp_level = 0; decomp_level < l_gadget; decomp_level++) {
|
||||
for (int level = 0; level < level_count; level++) {
|
||||
|
||||
gadget.decompose_one_level(accumulator_mask_decomposed,
|
||||
accumulator_mask_rotated, decomp_level);
|
||||
accumulator_mask_rotated, level);
|
||||
|
||||
gadget.decompose_one_level(accumulator_body_decomposed,
|
||||
accumulator_body_rotated, decomp_level);
|
||||
accumulator_body_rotated, level);
|
||||
|
||||
synchronize_threads_in_block();
|
||||
|
||||
@@ -187,11 +188,11 @@ __global__ void device_bootstrap_amortized(
|
||||
// Get the bootstrapping key piece necessary for the multiplication
|
||||
// It is already in the Fourier domain
|
||||
auto bsk_mask_slice = PolynomialFourier<double2, params>(
|
||||
get_ith_mask_kth_block(bootstrapping_key, iteration, 0, decomp_level,
|
||||
polynomial_size, 1, l_gadget));
|
||||
get_ith_mask_kth_block(bootstrapping_key, iteration, 0, level,
|
||||
polynomial_size, 1, level_count));
|
||||
auto bsk_body_slice = PolynomialFourier<double2, params>(
|
||||
get_ith_body_kth_block(bootstrapping_key, iteration, 0, decomp_level,
|
||||
polynomial_size, 1, l_gadget));
|
||||
get_ith_body_kth_block(bootstrapping_key, iteration, 0, level,
|
||||
polynomial_size, 1, level_count));
|
||||
|
||||
synchronize_threads_in_block();
|
||||
|
||||
@@ -216,11 +217,11 @@ __global__ void device_bootstrap_amortized(
|
||||
correction_direct_fft_inplace<params>(accumulator_fft);
|
||||
|
||||
auto bsk_mask_slice_2 = PolynomialFourier<double2, params>(
|
||||
get_ith_mask_kth_block(bootstrapping_key, iteration, 1, decomp_level,
|
||||
polynomial_size, 1, l_gadget));
|
||||
get_ith_mask_kth_block(bootstrapping_key, iteration, 1, level,
|
||||
polynomial_size, 1, level_count));
|
||||
auto bsk_body_slice_2 = PolynomialFourier<double2, params>(
|
||||
get_ith_body_kth_block(bootstrapping_key, iteration, 1, decomp_level,
|
||||
polynomial_size, 1, l_gadget));
|
||||
get_ith_body_kth_block(bootstrapping_key, iteration, 1, level,
|
||||
polynomial_size, 1, level_count));
|
||||
|
||||
synchronize_threads_in_block();
|
||||
|
||||
@@ -283,23 +284,24 @@ __global__ void device_bootstrap_amortized(
|
||||
}
|
||||
}
|
||||
|
||||
auto block_lwe_out = &lwe_out[blockIdx.x * (polynomial_size + 1)];
|
||||
auto block_lwe_array_out = &lwe_array_out[blockIdx.x * (polynomial_size + 1)];
|
||||
|
||||
// The blind rotation for this block is over
|
||||
// Now we can perform the sample extraction: for the body it's just
|
||||
// the resulting constant coefficient of the accumulator
|
||||
// For the mask it's more complicated
|
||||
sample_extract_mask<Torus, params>(block_lwe_out, accumulator_mask);
|
||||
sample_extract_body<Torus, params>(block_lwe_out, accumulator_body);
|
||||
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator_mask);
|
||||
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator_body);
|
||||
}
|
||||
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_bootstrap_amortized(
|
||||
void *v_stream, Torus *lwe_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_in, double2 *bootstrapping_key,
|
||||
uint32_t input_lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t input_lwe_ciphertext_count,
|
||||
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
void *v_stream, Torus *lwe_array_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_array_in,
|
||||
double2 *bootstrapping_key, uint32_t input_lwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
int SM_FULL = sizeof(Torus) * polynomial_size + // accumulator mask
|
||||
sizeof(Torus) * polynomial_size + // accumulator body
|
||||
@@ -338,9 +340,9 @@ __host__ void host_bootstrap_amortized(
|
||||
checkCudaErrors(
|
||||
cudaMalloc((void **)&d_mem, DM_FULL * input_lwe_ciphertext_count));
|
||||
device_bootstrap_amortized<Torus, params, NOSM><<<grid, thds, 0, *stream>>>(
|
||||
lwe_out, lut_vector, lut_vector_indexes, lwe_in, bootstrapping_key,
|
||||
d_mem, input_lwe_dimension, polynomial_size, base_log, l_gadget,
|
||||
lwe_idx, DM_FULL);
|
||||
lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
|
||||
base_log, level_count, lwe_idx, DM_FULL);
|
||||
} else if (max_shared_memory < SM_FULL) {
|
||||
cudaFuncSetAttribute(device_bootstrap_amortized<Torus, params, PARTIALSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART);
|
||||
@@ -350,9 +352,9 @@ __host__ void host_bootstrap_amortized(
|
||||
cudaMalloc((void **)&d_mem, DM_PART * input_lwe_ciphertext_count));
|
||||
device_bootstrap_amortized<Torus, params, PARTIALSM>
|
||||
<<<grid, thds, SM_PART, *stream>>>(
|
||||
lwe_out, lut_vector, lut_vector_indexes, lwe_in, bootstrapping_key,
|
||||
d_mem, input_lwe_dimension, polynomial_size, base_log, l_gadget,
|
||||
lwe_idx, DM_PART);
|
||||
lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
|
||||
base_log, level_count, lwe_idx, DM_PART);
|
||||
} else {
|
||||
// For devices with compute capability 7.x a single thread block can
|
||||
// address the full capacity of shared memory. Shared memory on the
|
||||
@@ -369,12 +371,12 @@ __host__ void host_bootstrap_amortized(
|
||||
|
||||
device_bootstrap_amortized<Torus, params, FULLSM>
|
||||
<<<grid, thds, SM_FULL, *stream>>>(
|
||||
lwe_out, lut_vector, lut_vector_indexes, lwe_in, bootstrapping_key,
|
||||
d_mem, input_lwe_dimension, polynomial_size, base_log, l_gadget,
|
||||
lwe_idx, 0);
|
||||
lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
|
||||
bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
|
||||
base_log, level_count, lwe_idx, 0);
|
||||
}
|
||||
// Synchronize the streams before copying the result to lwe_out at the right
|
||||
// place
|
||||
// Synchronize the streams before copying the result to lwe_array_out at the
|
||||
// right place
|
||||
cudaStreamSynchronize(*stream);
|
||||
cudaFree(d_mem);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
/* Perform bootstrapping on a batch of input LWE ciphertexts
|
||||
*
|
||||
* - lwe_out: output batch of num_samples bootstrapped ciphertexts c =
|
||||
* - lwe_array_out: output batch of num_samples bootstrapped ciphertexts c =
|
||||
* (a0,..an-1,b) where n is the LWE dimension
|
||||
* - lut_vector: should hold as many test vectors of size polynomial_size
|
||||
* as there are input ciphertexts, but actually holds
|
||||
@@ -10,7 +10,7 @@
|
||||
* - lut_vector_indexes: stores the index corresponding to
|
||||
* which test vector to use for each sample in
|
||||
* lut_vector
|
||||
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* mask values + 1 body value
|
||||
* - bootstrapping_key: RGSW encryption of the LWE secret key sk1
|
||||
* under secret key sk2
|
||||
@@ -30,7 +30,7 @@
|
||||
* - polynomial_size: size of the test polynomial (test vector) and size of the
|
||||
* GLWE polynomial (~1024)
|
||||
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
||||
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
||||
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
||||
* - num_samples: number of encrypted input messages
|
||||
* - num_lut_vectors: parameter to set the actual number of test vectors to be
|
||||
* used
|
||||
@@ -44,7 +44,7 @@
|
||||
* to handle one or more polynomial coefficients at each stage:
|
||||
* - perform the blind rotation
|
||||
* - round the result
|
||||
* - decompose into l_gadget levels, then for each level:
|
||||
* - decompose into level_count levels, then for each level:
|
||||
* - switch to the FFT domain
|
||||
* - multiply with the bootstrapping key
|
||||
* - come back to the coefficients representation
|
||||
@@ -57,11 +57,11 @@
|
||||
* values for the FFT
|
||||
*/
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
assert(("Error (GPU low latency PBS): base log should be <= 16",
|
||||
base_log <= 16));
|
||||
@@ -79,44 +79,30 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
assert(("Error (GPU low latency PBS): the number of input LWEs must be lower "
|
||||
"or equal to the "
|
||||
"number of streaming multiprocessors on the device divided by 8 * "
|
||||
"l_gadget",
|
||||
num_samples <= number_of_sm / 4. / 2. / l_gadget));
|
||||
"level_count",
|
||||
num_samples <= number_of_sm / 4. / 2. / level_count));
|
||||
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 4096:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<4096>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 8192:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<8192>>(
|
||||
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -124,11 +110,11 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
}
|
||||
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
|
||||
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
assert(("Error (GPU low latency PBS): base log should be <= 16",
|
||||
base_log <= 16));
|
||||
@@ -146,44 +132,30 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
assert(("Error (GPU low latency PBS): the number of input LWEs must be lower "
|
||||
"or equal to the "
|
||||
"number of streaming multiprocessors on the device divided by 8 * "
|
||||
"l_gadget",
|
||||
num_samples <= number_of_sm / 4. / 2. / l_gadget));
|
||||
"level_count",
|
||||
num_samples <= number_of_sm / 4. / 2. / level_count));
|
||||
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 4096:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<4096>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 8192:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<8192>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, num_samples, num_lut_vectors);
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -29,13 +29,13 @@ namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Torus, class params>
|
||||
__device__ void
|
||||
mul_trgsw_trlwe(Torus *accumulator, double2 *fft, int16_t *trlwe_decomposed,
|
||||
double2 *mask_join_buffer, double2 *body_join_buffer,
|
||||
double2 *bootstrapping_key, int polynomial_size, int l_gadget,
|
||||
int iteration, grid_group &grid) {
|
||||
mul_ggsw_glwe(Torus *accumulator, double2 *fft, int16_t *glwe_decomposed,
|
||||
double2 *mask_join_buffer, double2 *body_join_buffer,
|
||||
double2 *bootstrapping_key, int polynomial_size, int level_count,
|
||||
int iteration, grid_group &grid) {
|
||||
|
||||
// Put the decomposed TRLWE sample in the Fourier domain
|
||||
real_to_complex_compressed<int16_t, params>(trlwe_decomposed, fft);
|
||||
// Put the decomposed GLWE sample in the Fourier domain
|
||||
real_to_complex_compressed<int16_t, params>(glwe_decomposed, fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Switch to the FFT space
|
||||
@@ -53,12 +53,12 @@ mul_trgsw_trlwe(Torus *accumulator, double2 *fft, int16_t *trlwe_decomposed,
|
||||
|
||||
auto bsk_mask_slice = PolynomialFourier<double2, params>(
|
||||
get_ith_mask_kth_block(bootstrapping_key, iteration, blockIdx.y,
|
||||
blockIdx.x, polynomial_size, 1, l_gadget));
|
||||
blockIdx.x, polynomial_size, 1, level_count));
|
||||
auto bsk_body_slice = PolynomialFourier<double2, params>(
|
||||
get_ith_body_kth_block(bootstrapping_key, iteration, blockIdx.y,
|
||||
blockIdx.x, polynomial_size, 1, l_gadget));
|
||||
blockIdx.x, polynomial_size, 1, level_count));
|
||||
|
||||
// Perform the matrix multiplication between the RGSW and the TRLWE,
|
||||
// Perform the matrix multiplication between the GGSW and the GLWE,
|
||||
// each block operating on a single level for mask and body
|
||||
|
||||
auto first_processed_bsk =
|
||||
@@ -120,7 +120,7 @@ mul_trgsw_trlwe(Torus *accumulator, double2 *fft, int16_t *trlwe_decomposed,
|
||||
correction_inverse_fft_inplace<params>(fft);
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Perform the inverse FFT on the result of the RGSWxTRLWE and add to the
|
||||
// Perform the inverse FFT on the result of the GGSW x GWE and add to the
|
||||
// accumulator
|
||||
NSMFFT_inverse<HalfDegree<params>>(fft);
|
||||
synchronize_threads_in_block();
|
||||
@@ -134,17 +134,18 @@ template <typename Torus, class params>
|
||||
/*
|
||||
* Kernel launched by the low latency version of the
|
||||
* bootstrapping, that uses cooperative groups
|
||||
* lwe_out vector of output lwe s, with length (polynomial_size+1)*num_samples
|
||||
* lut_vector - vector of look up tables with length polynomial_size *
|
||||
* num_samples lut_vector_indexes - mapping between lwe_in and lut_vector lwe_in
|
||||
* - vector of lwe inputs with length (lwe_mask_size + 1) * num_samples
|
||||
* lwe_array_out vector of output lwe s, with length
|
||||
* (polynomial_size+1)*num_samples lut_vector - vector of look up tables with
|
||||
* length polynomial_size * num_samples lut_vector_indexes - mapping between
|
||||
* lwe_array_in and lut_vector lwe_array_in
|
||||
* - vector of lwe inputs with length (lwe_dimension + 1) * num_samples
|
||||
*
|
||||
*/
|
||||
__global__ void device_bootstrap_low_latency(
|
||||
Torus *lwe_out, Torus *lut_vector, Torus *lwe_in,
|
||||
Torus *lwe_array_out, Torus *lut_vector, Torus *lwe_array_in,
|
||||
double2 *bootstrapping_key, double2 *mask_join_buffer,
|
||||
double2 *body_join_buffer, uint32_t lwe_mask_size, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t l_gadget) {
|
||||
double2 *body_join_buffer, uint32_t lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count) {
|
||||
|
||||
grid_group grid = this_grid();
|
||||
|
||||
@@ -167,23 +168,23 @@ __global__ void device_bootstrap_low_latency(
|
||||
|
||||
// The third dimension of the block is used to determine on which ciphertext
|
||||
// this block is operating, in the case of batch bootstraps
|
||||
auto block_lwe_in = &lwe_in[blockIdx.z * (lwe_mask_size + 1)];
|
||||
auto block_lwe_array_in = &lwe_array_in[blockIdx.z * (lwe_dimension + 1)];
|
||||
|
||||
auto block_lut_vector = &lut_vector[blockIdx.z * params::degree * 2];
|
||||
|
||||
auto block_mask_join_buffer =
|
||||
&mask_join_buffer[blockIdx.z * l_gadget * params::degree / 2];
|
||||
&mask_join_buffer[blockIdx.z * level_count * params::degree / 2];
|
||||
auto block_body_join_buffer =
|
||||
&body_join_buffer[blockIdx.z * l_gadget * params::degree / 2];
|
||||
&body_join_buffer[blockIdx.z * level_count * params::degree / 2];
|
||||
|
||||
// Since the space is L1 cache is small, we use the same memory location for
|
||||
// the rotated accumulator and the fft accumulator, since we know that the
|
||||
// rotated array is not in use anymore by the time we perform the fft
|
||||
GadgetMatrix<Torus, params> gadget(base_log, l_gadget);
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count);
|
||||
|
||||
// Put "b" in [0, 2N[
|
||||
Torus b_hat =
|
||||
rescale_torus_element(block_lwe_in[lwe_mask_size], 2 * params::degree);
|
||||
Torus b_hat = rescale_torus_element(block_lwe_array_in[lwe_dimension],
|
||||
2 * params::degree);
|
||||
|
||||
if (blockIdx.y == 0) {
|
||||
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
|
||||
@@ -195,12 +196,12 @@ __global__ void device_bootstrap_low_latency(
|
||||
accumulator, &block_lut_vector[params::degree], b_hat, false);
|
||||
}
|
||||
|
||||
for (int i = 0; i < lwe_mask_size; i++) {
|
||||
for (int i = 0; i < lwe_dimension; i++) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Put "a" in [0, 2N[
|
||||
Torus a_hat = rescale_torus_element(
|
||||
block_lwe_in[i],
|
||||
block_lwe_array_in[i],
|
||||
2 * params::degree); // 2 * params::log2_degree + 1);
|
||||
|
||||
// Perform ACC * (X^ä - 1)
|
||||
@@ -212,7 +213,7 @@ __global__ void device_bootstrap_low_latency(
|
||||
// bootstrapped ciphertext
|
||||
round_to_closest_multiple_inplace<Torus, params::opt,
|
||||
params::degree / params::opt>(
|
||||
accumulator_rotated, base_log, l_gadget);
|
||||
accumulator_rotated, base_log, level_count);
|
||||
|
||||
// Decompose the accumulator. Each block gets one level of the
|
||||
// decomposition, for the mask and the body (so block 0 will have the
|
||||
@@ -224,22 +225,22 @@ __global__ void device_bootstrap_low_latency(
|
||||
// accumulator_rotated, so we need to synchronize here to make sure they
|
||||
// don't modify the same memory space at the same time
|
||||
synchronize_threads_in_block();
|
||||
// Perform G^-1(ACC) * RGSW -> TRLWE
|
||||
mul_trgsw_trlwe<Torus, params>(
|
||||
accumulator, accumulator_fft, accumulator_decomposed,
|
||||
block_mask_join_buffer, block_body_join_buffer, bootstrapping_key,
|
||||
polynomial_size, l_gadget, i, grid);
|
||||
// Perform G^-1(ACC) * GGSW -> GLWE
|
||||
mul_ggsw_glwe<Torus, params>(accumulator, accumulator_fft,
|
||||
accumulator_decomposed, block_mask_join_buffer,
|
||||
block_body_join_buffer, bootstrapping_key,
|
||||
polynomial_size, level_count, i, grid);
|
||||
}
|
||||
|
||||
auto block_lwe_out = &lwe_out[blockIdx.z * (polynomial_size + 1)];
|
||||
auto block_lwe_array_out = &lwe_array_out[blockIdx.z * (polynomial_size + 1)];
|
||||
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0) {
|
||||
// Perform a sample extract. At this point, all blocks have the result, but
|
||||
// we do the computation at block 0 to avoid waiting for extra blocks, in
|
||||
// case they're not synchronized
|
||||
sample_extract_mask<Torus, params>(block_lwe_out, accumulator);
|
||||
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator);
|
||||
} else if (blockIdx.x == 0 && blockIdx.y == 1) {
|
||||
sample_extract_body<Torus, params>(block_lwe_out, accumulator);
|
||||
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,16 +249,18 @@ __global__ void device_bootstrap_low_latency(
|
||||
* of bootstrapping
|
||||
*/
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_bootstrap_low_latency(
|
||||
void *v_stream, Torus *lwe_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_in, double2 *bootstrapping_key,
|
||||
uint32_t lwe_mask_size, uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors) {
|
||||
__host__ void
|
||||
host_bootstrap_low_latency(void *v_stream, Torus *lwe_array_out,
|
||||
Torus *lut_vector, uint32_t *lut_vector_indexes,
|
||||
Torus *lwe_array_in, double2 *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples, uint32_t num_lut_vectors) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
int buffer_size_per_gpu =
|
||||
l_gadget * num_samples * polynomial_size / 2 * sizeof(double2);
|
||||
level_count * num_samples * polynomial_size / 2 * sizeof(double2);
|
||||
double2 *mask_buffer_fft;
|
||||
double2 *body_buffer_fft;
|
||||
checkCudaErrors(cudaMalloc((void **)&mask_buffer_fft, buffer_size_per_gpu));
|
||||
@@ -268,19 +271,19 @@ __host__ void host_bootstrap_low_latency(
|
||||
sizeof(double2) * polynomial_size / 2; // accumulator fft
|
||||
|
||||
int thds = polynomial_size / params::opt;
|
||||
dim3 grid(l_gadget, 2, num_samples);
|
||||
dim3 grid(level_count, 2, num_samples);
|
||||
|
||||
void *kernel_args[10];
|
||||
kernel_args[0] = &lwe_out;
|
||||
kernel_args[0] = &lwe_array_out;
|
||||
kernel_args[1] = &lut_vector;
|
||||
kernel_args[2] = &lwe_in;
|
||||
kernel_args[2] = &lwe_array_in;
|
||||
kernel_args[3] = &bootstrapping_key;
|
||||
kernel_args[4] = &mask_buffer_fft;
|
||||
kernel_args[5] = &body_buffer_fft;
|
||||
kernel_args[6] = &lwe_mask_size;
|
||||
kernel_args[6] = &lwe_dimension;
|
||||
kernel_args[7] = &polynomial_size;
|
||||
kernel_args[8] = &base_log;
|
||||
kernel_args[9] = &l_gadget;
|
||||
kernel_args[9] = &level_count;
|
||||
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
device_bootstrap_low_latency<Torus, params>,
|
||||
@@ -292,8 +295,8 @@ __host__ void host_bootstrap_low_latency(
|
||||
(void *)device_bootstrap_low_latency<Torus, params>, grid, thds,
|
||||
(void **)kernel_args, bytes_needed, *stream));
|
||||
|
||||
// Synchronize the streams before copying the result to lwe_out at the right
|
||||
// place
|
||||
// Synchronize the streams before copying the result to lwe_array_out at the
|
||||
// right place
|
||||
cudaStreamSynchronize(*stream);
|
||||
cudaFree(mask_buffer_fft);
|
||||
cudaFree(body_buffer_fft);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "bootstrap_wop.cuh"
|
||||
|
||||
void cuda_cmux_tree_32(void *v_stream, void *glwe_out, void *ggsw_in,
|
||||
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t r,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory) {
|
||||
|
||||
assert(("Error (GPU Cmux tree): base log should be <= 16", base_log <= 16));
|
||||
@@ -22,43 +22,43 @@ void cuda_cmux_tree_32(void *v_stream, void *glwe_out, void *ggsw_in,
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<4096>>(
|
||||
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<8192>>(
|
||||
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_cmux_tree_64(void *v_stream, void *glwe_out, void *ggsw_in,
|
||||
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t r,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory) {
|
||||
|
||||
assert(("Error (GPU Cmux tree): base log should be <= 16", base_log <= 16));
|
||||
@@ -77,57 +77,56 @@ void cuda_cmux_tree_64(void *v_stream, void *glwe_out, void *ggsw_in,
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<4096>>(
|
||||
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<8192>>(
|
||||
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
l_gadget, r, max_shared_memory);
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_extract_bits_32(void *v_stream, void *list_lwe_out, void *lwe_in,
|
||||
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk,
|
||||
void *fourier_bsk, uint32_t number_of_bits,
|
||||
uint32_t delta_log, uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
|
||||
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
void cuda_extract_bits_32(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
assert(("Error (GPU extract bits): base log should be <= 16",
|
||||
base_log_bsk <= 16));
|
||||
assert(("Error (GPU extract bits): glwe_dimension should be equal to 1",
|
||||
glwe_dimension == 1));
|
||||
assert(("Error (GPU extract bits): lwe_dimension_before should be one of "
|
||||
assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
|
||||
"512, 1024, 2048",
|
||||
lwe_dimension_before == 512 || lwe_dimension_before == 1024 ||
|
||||
lwe_dimension_before == 2048));
|
||||
lwe_dimension_in == 512 || lwe_dimension_in == 1024 ||
|
||||
lwe_dimension_in == 2048));
|
||||
// The number of samples should be lower than the number of streaming
|
||||
// multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
|
||||
// to the occupancy of 50%). The only supported value for k is 1, so
|
||||
@@ -137,63 +136,68 @@ void cuda_extract_bits_32(void *v_stream, void *list_lwe_out, void *lwe_in,
|
||||
assert(("Error (GPU extract bits): the number of input LWEs must be lower or "
|
||||
"equal to the "
|
||||
"number of streaming multiprocessors on the device divided by 8 * "
|
||||
"l_gadget_bsk",
|
||||
number_of_samples <= number_of_sm / 4. / 2. / l_gadget_bsk));
|
||||
"level_count_bsk",
|
||||
number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
|
||||
|
||||
switch (lwe_dimension_before) {
|
||||
switch (lwe_dimension_in) {
|
||||
case 512:
|
||||
host_extract_bits<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
|
||||
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
|
||||
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
|
||||
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
|
||||
l_gadget_ksk, number_of_samples);
|
||||
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
|
||||
(uint32_t *)lwe_array_in_buffer,
|
||||
(uint32_t *)lwe_array_in_shifted_buffer,
|
||||
(uint32_t *)lwe_array_out_ks_buffer,
|
||||
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
|
||||
number_of_samples);
|
||||
break;
|
||||
case 1024:
|
||||
host_extract_bits<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
|
||||
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
|
||||
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
|
||||
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
|
||||
l_gadget_ksk, number_of_samples);
|
||||
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
|
||||
(uint32_t *)lwe_array_in_buffer,
|
||||
(uint32_t *)lwe_array_in_shifted_buffer,
|
||||
(uint32_t *)lwe_array_out_ks_buffer,
|
||||
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
|
||||
number_of_samples);
|
||||
break;
|
||||
case 2048:
|
||||
host_extract_bits<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
|
||||
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
|
||||
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
|
||||
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
|
||||
l_gadget_ksk, number_of_samples);
|
||||
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
|
||||
(uint32_t *)lwe_array_in_buffer,
|
||||
(uint32_t *)lwe_array_in_shifted_buffer,
|
||||
(uint32_t *)lwe_array_out_ks_buffer,
|
||||
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
|
||||
number_of_samples);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_extract_bits_64(void *v_stream, void *list_lwe_out, void *lwe_in,
|
||||
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
|
||||
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk,
|
||||
void *fourier_bsk, uint32_t number_of_bits,
|
||||
uint32_t delta_log, uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
|
||||
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
void cuda_extract_bits_64(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
assert(("Error (GPU extract bits): base log should be <= 16",
|
||||
base_log_bsk <= 16));
|
||||
assert(("Error (GPU extract bits): glwe_dimension should be equal to 1",
|
||||
glwe_dimension == 1));
|
||||
assert(("Error (GPU extract bits): lwe_dimension_before should be one of "
|
||||
assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
|
||||
"512, 1024, 2048",
|
||||
lwe_dimension_before == 512 || lwe_dimension_before == 1024 ||
|
||||
lwe_dimension_before == 2048));
|
||||
lwe_dimension_in == 512 || lwe_dimension_in == 1024 ||
|
||||
lwe_dimension_in == 2048));
|
||||
// The number of samples should be lower than the number of streaming
|
||||
// multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
|
||||
// to the occupancy of 50%). The only supported value for k is 1, so
|
||||
@@ -203,39 +207,45 @@ void cuda_extract_bits_64(void *v_stream, void *list_lwe_out, void *lwe_in,
|
||||
assert(("Error (GPU extract bits): the number of input LWEs must be lower or "
|
||||
"equal to the "
|
||||
"number of streaming multiprocessors on the device divided by 8 * "
|
||||
"l_gadget_bsk",
|
||||
number_of_samples <= number_of_sm / 4. / 2. / l_gadget_bsk));
|
||||
"level_count_bsk",
|
||||
number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
|
||||
|
||||
switch (lwe_dimension_before) {
|
||||
switch (lwe_dimension_in) {
|
||||
case 512:
|
||||
host_extract_bits<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
|
||||
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
|
||||
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
|
||||
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
|
||||
l_gadget_ksk, number_of_samples);
|
||||
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lwe_array_in_buffer,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
|
||||
number_of_samples);
|
||||
break;
|
||||
case 1024:
|
||||
host_extract_bits<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
|
||||
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
|
||||
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
|
||||
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
|
||||
l_gadget_ksk, number_of_samples);
|
||||
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lwe_array_in_buffer,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
|
||||
number_of_samples);
|
||||
break;
|
||||
case 2048:
|
||||
host_extract_bits<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
|
||||
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
|
||||
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
|
||||
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
|
||||
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
|
||||
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
|
||||
l_gadget_ksk, number_of_samples);
|
||||
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lwe_array_in_buffer,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
|
||||
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
|
||||
number_of_samples);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -68,12 +68,12 @@ template <class params> __device__ void ifft_inplace(double2 *data) {
|
||||
* Receives an array of GLWE ciphertexts and two indexes to ciphertexts in this
|
||||
* array, and an array of GGSW ciphertexts with a index to one ciphertext in it.
|
||||
* Compute a CMUX with these operands and writes the output to a particular
|
||||
* index of glwe_out.
|
||||
* index of glwe_array_out.
|
||||
*
|
||||
* This function needs polynomial_size threads per block.
|
||||
*
|
||||
* - glwe_out: An array where the result should be written to.
|
||||
* - glwe_in: An array where the GLWE inputs are stored.
|
||||
* - glwe_array_out: An array where the result should be written to.
|
||||
* - glwe_array_in: An array where the GLWE inputs are stored.
|
||||
* - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
|
||||
* - selected_memory: An array to be used for the accumulators. Can be in the
|
||||
* shared memory or global memory.
|
||||
@@ -84,15 +84,15 @@ template <class params> __device__ void ifft_inplace(double2 *data) {
|
||||
* - glwe_dim: This is k.
|
||||
* - polynomial_size: size of the polynomials. This is N.
|
||||
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
||||
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
||||
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
||||
* - ggsw_idx: The index of the GGSW we will use.
|
||||
*/
|
||||
template <typename Torus, typename STorus, class params>
|
||||
__device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
char *selected_memory, uint32_t output_idx,
|
||||
uint32_t input_idx1, uint32_t input_idx2,
|
||||
uint32_t glwe_dim, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t ggsw_idx) {
|
||||
__device__ void
|
||||
cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
char *selected_memory, uint32_t output_idx, uint32_t input_idx1,
|
||||
uint32_t input_idx2, uint32_t glwe_dim, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {
|
||||
|
||||
// Define glwe_sub
|
||||
Torus *glwe_sub_mask = (Torus *)selected_memory;
|
||||
@@ -109,18 +109,18 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
double2 *glwe_fft =
|
||||
(double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
|
||||
|
||||
GadgetMatrix<Torus, params> gadget(base_log, l_gadget);
|
||||
GadgetMatrix<Torus, params> gadget(base_log, level_count);
|
||||
|
||||
/////////////////////////////////////
|
||||
|
||||
// glwe2-glwe1
|
||||
|
||||
// Copy m0 to shared memory to preserve data
|
||||
auto m0_mask = &glwe_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
|
||||
auto m0_mask = &glwe_array_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
|
||||
auto m0_body = m0_mask + polynomial_size;
|
||||
|
||||
// Just gets the pointer for m1 on global memory
|
||||
auto m1_mask = &glwe_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
|
||||
auto m1_mask = &glwe_array_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
|
||||
auto m1_body = m1_mask + polynomial_size;
|
||||
|
||||
// Mask
|
||||
@@ -145,13 +145,11 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
// Subtract each glwe operand, decompose the resulting
|
||||
// polynomial coefficients to multiply each decomposed level
|
||||
// with the corresponding part of the LUT
|
||||
for (int decomp_level = 0; decomp_level < l_gadget; decomp_level++) {
|
||||
for (int level = 0; level < level_count; level++) {
|
||||
|
||||
// Decomposition
|
||||
gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask,
|
||||
decomp_level);
|
||||
gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body,
|
||||
decomp_level);
|
||||
gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask, level);
|
||||
gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body, level);
|
||||
|
||||
// First, perform the polynomial multiplication for the mask
|
||||
synchronize_threads_in_block();
|
||||
@@ -159,12 +157,10 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
|
||||
// External product and accumulate
|
||||
// Get the piece necessary for the multiplication
|
||||
auto mask_fourier =
|
||||
get_ith_mask_kth_block(ggsw_in, ggsw_idx, 0, decomp_level,
|
||||
polynomial_size, glwe_dim, l_gadget);
|
||||
auto body_fourier =
|
||||
get_ith_body_kth_block(ggsw_in, ggsw_idx, 0, decomp_level,
|
||||
polynomial_size, glwe_dim, l_gadget);
|
||||
auto mask_fourier = get_ith_mask_kth_block(
|
||||
ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
|
||||
auto body_fourier = get_ith_body_kth_block(
|
||||
ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
|
||||
|
||||
synchronize_threads_in_block();
|
||||
|
||||
@@ -182,10 +178,10 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
|
||||
// External product and accumulate
|
||||
// Get the piece necessary for the multiplication
|
||||
mask_fourier = get_ith_mask_kth_block(ggsw_in, ggsw_idx, 1, decomp_level,
|
||||
polynomial_size, glwe_dim, l_gadget);
|
||||
body_fourier = get_ith_body_kth_block(ggsw_in, ggsw_idx, 1, decomp_level,
|
||||
polynomial_size, glwe_dim, l_gadget);
|
||||
mask_fourier = get_ith_mask_kth_block(
|
||||
ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
|
||||
body_fourier = get_ith_body_kth_block(
|
||||
ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
|
||||
|
||||
synchronize_threads_in_block();
|
||||
|
||||
@@ -202,7 +198,8 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Write the output
|
||||
Torus *mb_mask = &glwe_out[output_idx * (glwe_dim + 1) * polynomial_size];
|
||||
Torus *mb_mask =
|
||||
&glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
|
||||
Torus *mb_body = mb_mask + polynomial_size;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
@@ -221,8 +218,8 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
* ciphertext. The GLWE ciphertexts are picked two-by-two in sequence. Each
|
||||
* thread block computes a single CMUX.
|
||||
*
|
||||
* - glwe_out: An array where the result should be written to.
|
||||
* - glwe_in: An array where the GLWE inputs are stored.
|
||||
* - glwe_array_out: An array where the result should be written to.
|
||||
* - glwe_array_in: An array where the GLWE inputs are stored.
|
||||
* - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
|
||||
* - device_mem: An pointer for the global memory in case the shared memory is
|
||||
* not big enough to store the accumulators.
|
||||
@@ -231,15 +228,15 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
* - glwe_dim: This is k.
|
||||
* - polynomial_size: size of the polynomials. This is N.
|
||||
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
||||
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
||||
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
||||
* - ggsw_idx: The index of the GGSW we will use.
|
||||
*/
|
||||
template <typename Torus, typename STorus, class params, sharedMemDegree SMD>
|
||||
__global__ void
|
||||
device_batch_cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
char *device_mem, size_t device_memory_size_per_block,
|
||||
uint32_t glwe_dim, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t ggsw_idx) {
|
||||
uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {
|
||||
|
||||
int cmux_idx = blockIdx.x;
|
||||
int output_idx = cmux_idx;
|
||||
@@ -255,9 +252,10 @@ device_batch_cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
else
|
||||
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_block];
|
||||
|
||||
cmux<Torus, STorus, params>(glwe_out, glwe_in, ggsw_in, selected_memory,
|
||||
output_idx, input_idx1, input_idx2, glwe_dim,
|
||||
polynomial_size, base_log, l_gadget, ggsw_idx);
|
||||
cmux<Torus, STorus, params>(glwe_array_out, glwe_array_in, ggsw_in,
|
||||
selected_memory, output_idx, input_idx1,
|
||||
input_idx2, glwe_dim, polynomial_size, base_log,
|
||||
level_count, ggsw_idx);
|
||||
}
|
||||
/*
|
||||
* This kernel executes the CMUX tree used by the hybrid packing of the WoPBS.
|
||||
@@ -265,20 +263,21 @@ device_batch_cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
|
||||
* Uses shared memory for intermediate results
|
||||
*
|
||||
* - v_stream: The CUDA stream that should be used.
|
||||
* - glwe_out: A device array for the output GLWE ciphertext.
|
||||
* - glwe_array_out: A device array for the output GLWE ciphertext.
|
||||
* - ggsw_in: A device array for the GGSW ciphertexts used in each layer.
|
||||
* - lut_vector: A device array for the GLWE ciphertexts used in the first
|
||||
* layer.
|
||||
* - polynomial_size: size of the polynomials. This is N.
|
||||
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
|
||||
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
|
||||
* - level_count: number of decomposition levels in the gadget matrix (~4)
|
||||
* - r: Number of layers in the tree.
|
||||
*/
|
||||
template <typename Torus, typename STorus, class params>
|
||||
void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
|
||||
void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
|
||||
Torus *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t l_gadget, uint32_t r, uint32_t max_shared_memory) {
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
int num_lut = (1 << r);
|
||||
@@ -299,7 +298,7 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
|
||||
//////////////////////
|
||||
double2 *d_ggsw_fft_in;
|
||||
int ggsw_size = r * polynomial_size * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) * l_gadget;
|
||||
(glwe_dimension + 1) * level_count;
|
||||
|
||||
#if (CUDART_VERSION < 11020)
|
||||
checkCudaErrors(
|
||||
@@ -311,7 +310,7 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
|
||||
|
||||
batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
|
||||
r, glwe_dimension,
|
||||
polynomial_size, l_gadget);
|
||||
polynomial_size, level_count);
|
||||
|
||||
//////////////////////
|
||||
|
||||
@@ -368,7 +367,7 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
|
||||
<<<grid, thds, memory_needed_per_block, *stream>>>(
|
||||
output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
|
||||
glwe_dimension, // k
|
||||
polynomial_size, base_log, l_gadget,
|
||||
polynomial_size, base_log, level_count,
|
||||
layer_idx // r
|
||||
);
|
||||
else
|
||||
@@ -376,17 +375,18 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
|
||||
<<<grid, thds, memory_needed_per_block, *stream>>>(
|
||||
output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
|
||||
glwe_dimension, // k
|
||||
polynomial_size, base_log, l_gadget,
|
||||
polynomial_size, base_log, level_count,
|
||||
layer_idx // r
|
||||
);
|
||||
}
|
||||
|
||||
checkCudaErrors(cudaMemcpyAsync(
|
||||
glwe_out, output, (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
cudaMemcpyDeviceToDevice, *stream));
|
||||
checkCudaErrors(
|
||||
cudaMemcpyAsync(glwe_array_out, output,
|
||||
(glwe_dimension + 1) * polynomial_size * sizeof(Torus),
|
||||
cudaMemcpyDeviceToDevice, *stream));
|
||||
|
||||
// We only need synchronization to assert that data is in glwe_out before
|
||||
// returning. Memory release can be added to the stream and processed
|
||||
// We only need synchronization to assert that data is in glwe_array_out
|
||||
// before returning. Memory release can be added to the stream and processed
|
||||
// later.
|
||||
checkCudaErrors(cudaStreamSynchronize(*stream));
|
||||
|
||||
@@ -492,36 +492,38 @@ __global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
|
||||
// instead of alpha= 1ll << (ciphertext_n_bits - delta_log - bit_idx - 1)
|
||||
template <typename Torus, class params>
|
||||
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
|
||||
Torus *pbs_lwe_out, Torus add_value,
|
||||
Torus *pbs_lwe_array_out, Torus add_value,
|
||||
Torus mul_value) {
|
||||
size_t tid = threadIdx.x;
|
||||
size_t blockId = blockIdx.x;
|
||||
auto cur_shifted_lwe = &shifted_lwe[blockId * (params::degree + 1)];
|
||||
auto cur_state_lwe = &state_lwe[blockId * (params::degree + 1)];
|
||||
auto cur_pbs_lwe_out = &pbs_lwe_out[blockId * (params::degree + 1)];
|
||||
auto cur_pbs_lwe_array_out =
|
||||
&pbs_lwe_array_out[blockId * (params::degree + 1)];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_out[tid];
|
||||
cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_array_out[tid];
|
||||
cur_shifted_lwe[tid] *= mul_value;
|
||||
tid += params::degree / params::opt;
|
||||
}
|
||||
|
||||
if (threadIdx.x == params::degree / params::opt - 1) {
|
||||
cur_shifted_lwe[params::degree] = cur_state_lwe[params::degree] -=
|
||||
(cur_pbs_lwe_out[params::degree] + add_value);
|
||||
(cur_pbs_lwe_array_out[params::degree] + add_value);
|
||||
cur_shifted_lwe[params::degree] *= mul_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_extract_bits(
|
||||
void *v_stream, Torus *list_lwe_out, Torus *lwe_in, Torus *lwe_in_buffer,
|
||||
Torus *lwe_in_shifted_buffer, Torus *lwe_out_ks_buffer,
|
||||
Torus *lwe_out_pbs_buffer, Torus *lut_pbs, uint32_t *lut_vector_indexes,
|
||||
Torus *ksk, double2 *fourier_bsk, uint32_t number_of_bits,
|
||||
uint32_t delta_log, uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after, uint32_t base_log_bsk, uint32_t l_gadget_bsk,
|
||||
uint32_t base_log_ksk, uint32_t l_gadget_ksk, uint32_t number_of_samples) {
|
||||
void *v_stream, Torus *list_lwe_array_out, Torus *lwe_array_in,
|
||||
Torus *lwe_array_in_buffer, Torus *lwe_array_in_shifted_buffer,
|
||||
Torus *lwe_array_out_ks_buffer, Torus *lwe_array_out_pbs_buffer,
|
||||
Torus *lut_pbs, uint32_t *lut_vector_indexes, Torus *ksk,
|
||||
double2 *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_ksk,
|
||||
uint32_t level_count_ksk, uint32_t number_of_samples) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
|
||||
@@ -530,38 +532,38 @@ __host__ void host_extract_bits(
|
||||
int threads = params::degree / params::opt;
|
||||
|
||||
copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
|
||||
lwe_in_buffer, lwe_in_shifted_buffer, lwe_in,
|
||||
lwe_array_in_buffer, lwe_array_in_shifted_buffer, lwe_array_in,
|
||||
1ll << (ciphertext_n_bits - delta_log - 1));
|
||||
|
||||
for (int bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(
|
||||
v_stream, lwe_out_ks_buffer, lwe_in_shifted_buffer, ksk,
|
||||
lwe_dimension_before, lwe_dimension_after, base_log_ksk, l_gadget_ksk,
|
||||
1);
|
||||
v_stream, lwe_array_out_ks_buffer, lwe_array_in_shifted_buffer, ksk,
|
||||
lwe_dimension_in, lwe_dimension_out, base_log_ksk, level_count_ksk, 1);
|
||||
|
||||
copy_small_lwe<<<1, 256, 0, *stream>>>(
|
||||
list_lwe_out, lwe_out_ks_buffer, lwe_dimension_after + 1,
|
||||
list_lwe_array_out, lwe_array_out_ks_buffer, lwe_dimension_out + 1,
|
||||
number_of_bits, number_of_bits - bit_idx - 1);
|
||||
|
||||
if (bit_idx == number_of_bits - 1) {
|
||||
break;
|
||||
}
|
||||
|
||||
add_to_body<Torus><<<1, 1, 0, *stream>>>(
|
||||
lwe_out_ks_buffer, lwe_dimension_after, 1ll << (ciphertext_n_bits - 2));
|
||||
add_to_body<Torus><<<1, 1, 0, *stream>>>(lwe_array_out_ks_buffer,
|
||||
lwe_dimension_out,
|
||||
1ll << (ciphertext_n_bits - 2));
|
||||
|
||||
fill_lut_body_for_current_bit<Torus, params>
|
||||
<<<blocks, threads, 0, *stream>>>(
|
||||
lut_pbs, 0ll - 1ll << (delta_log - 1 + bit_idx));
|
||||
|
||||
host_bootstrap_low_latency<Torus, params>(
|
||||
v_stream, lwe_out_pbs_buffer, lut_pbs, lut_vector_indexes,
|
||||
lwe_out_ks_buffer, fourier_bsk, lwe_dimension_after,
|
||||
lwe_dimension_before, base_log_bsk, l_gadget_bsk, number_of_samples, 1);
|
||||
v_stream, lwe_array_out_pbs_buffer, lut_pbs, lut_vector_indexes,
|
||||
lwe_array_out_ks_buffer, fourier_bsk, lwe_dimension_out,
|
||||
lwe_dimension_in, base_log_bsk, level_count_bsk, number_of_samples, 1);
|
||||
|
||||
add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
|
||||
lwe_in_shifted_buffer, lwe_in_buffer, lwe_out_pbs_buffer,
|
||||
1ll << (delta_log - 1 + bit_idx),
|
||||
lwe_array_in_shifted_buffer, lwe_array_in_buffer,
|
||||
lwe_array_out_pbs_buffer, 1ll << (delta_log - 1 + bit_idx),
|
||||
1ll << (ciphertext_n_bits - delta_log - bit_idx - 2));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,16 +9,17 @@
|
||||
|
||||
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget) {
|
||||
uint32_t level_count) {
|
||||
return i * polynomial_size / 2 * (glwe_dimension + 1) * (glwe_dimension + 1) *
|
||||
l_gadget;
|
||||
level_count;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension, uint32_t l_gadget) {
|
||||
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension, l_gadget) +
|
||||
int glwe_dimension, uint32_t level_count) {
|
||||
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
|
||||
level_count) +
|
||||
level * polynomial_size / 2 * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) +
|
||||
k * polynomial_size / 2 * (glwe_dimension + 1)];
|
||||
@@ -27,8 +28,9 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
|
||||
template <typename T>
|
||||
__device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension, uint32_t l_gadget) {
|
||||
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension, l_gadget) +
|
||||
int glwe_dimension, uint32_t level_count) {
|
||||
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
|
||||
level_count) +
|
||||
level * polynomial_size / 2 * (glwe_dimension + 1) *
|
||||
(glwe_dimension + 1) +
|
||||
k * polynomial_size / 2 * (glwe_dimension + 1) +
|
||||
@@ -69,14 +71,14 @@ void cuda_initialize_twiddles(uint32_t polynomial_size, uint32_t gpu_index) {
|
||||
template <typename T, typename ST>
|
||||
void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
uint32_t gpu_index, uint32_t input_lwe_dim,
|
||||
uint32_t glwe_dim, uint32_t l_gadget,
|
||||
uint32_t glwe_dim, uint32_t level_count,
|
||||
uint32_t polynomial_size) {
|
||||
|
||||
cudaSetDevice(gpu_index);
|
||||
int shared_memory_size = sizeof(double) * polynomial_size;
|
||||
|
||||
int total_polynomials =
|
||||
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * l_gadget;
|
||||
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
|
||||
|
||||
// Here the buffer size is the size of double2 times the number of polynomials
|
||||
// times the polynomial size over 2 because the polynomials are compressed
|
||||
@@ -142,21 +144,21 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
void cuda_convert_lwe_bootstrap_key_32(void *dest, void *src, void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
uint32_t input_lwe_dim,
|
||||
uint32_t glwe_dim, uint32_t l_gadget,
|
||||
uint32_t glwe_dim, uint32_t level_count,
|
||||
uint32_t polynomial_size) {
|
||||
cuda_convert_lwe_bootstrap_key<uint32_t, int32_t>(
|
||||
(double2 *)dest, (int32_t *)src, v_stream, gpu_index, input_lwe_dim,
|
||||
glwe_dim, l_gadget, polynomial_size);
|
||||
glwe_dim, level_count, polynomial_size);
|
||||
}
|
||||
|
||||
void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream,
|
||||
uint32_t gpu_index,
|
||||
uint32_t input_lwe_dim,
|
||||
uint32_t glwe_dim, uint32_t l_gadget,
|
||||
uint32_t glwe_dim, uint32_t level_count,
|
||||
uint32_t polynomial_size) {
|
||||
cuda_convert_lwe_bootstrap_key<uint64_t, int64_t>(
|
||||
(double2 *)dest, (int64_t *)src, v_stream, gpu_index, input_lwe_dim,
|
||||
glwe_dim, l_gadget, polynomial_size);
|
||||
glwe_dim, level_count, polynomial_size);
|
||||
}
|
||||
|
||||
// We need these lines so the compiler knows how to specialize these functions
|
||||
@@ -164,31 +166,31 @@ template __device__ uint64_t *get_ith_mask_kth_block(uint64_t *ptr, int i,
|
||||
int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget);
|
||||
uint32_t level_count);
|
||||
template __device__ uint32_t *get_ith_mask_kth_block(uint32_t *ptr, int i,
|
||||
int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget);
|
||||
uint32_t level_count);
|
||||
template __device__ double2 *get_ith_mask_kth_block(double2 *ptr, int i, int k,
|
||||
int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget);
|
||||
uint32_t level_count);
|
||||
template __device__ uint64_t *get_ith_body_kth_block(uint64_t *ptr, int i,
|
||||
int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget);
|
||||
uint32_t level_count);
|
||||
template __device__ uint32_t *get_ith_body_kth_block(uint32_t *ptr, int i,
|
||||
int k, int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget);
|
||||
uint32_t level_count);
|
||||
template __device__ double2 *get_ith_body_kth_block(double2 *ptr, int i, int k,
|
||||
int level,
|
||||
uint32_t polynomial_size,
|
||||
int glwe_dimension,
|
||||
uint32_t l_gadget);
|
||||
uint32_t level_count);
|
||||
|
||||
#endif // CNCRT_BSK_H
|
||||
|
||||
@@ -7,20 +7,20 @@
|
||||
#pragma once
|
||||
template <typename T, class params> class GadgetMatrix {
|
||||
private:
|
||||
uint32_t l_gadget;
|
||||
uint32_t level_count;
|
||||
uint32_t base_log;
|
||||
uint32_t mask;
|
||||
uint32_t halfbg;
|
||||
T offset;
|
||||
|
||||
public:
|
||||
__device__ GadgetMatrix(uint32_t base_log, uint32_t l_gadget)
|
||||
: base_log(base_log), l_gadget(l_gadget) {
|
||||
__device__ GadgetMatrix(uint32_t base_log, uint32_t level_count)
|
||||
: base_log(base_log), level_count(level_count) {
|
||||
uint32_t bg = 1 << base_log;
|
||||
this->halfbg = bg / 2;
|
||||
this->mask = bg - 1;
|
||||
T temp = 0;
|
||||
for (int i = 0; i < this->l_gadget; i++) {
|
||||
for (int i = 0; i < this->level_count; i++) {
|
||||
temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
|
||||
}
|
||||
this->offset = temp * this->halfbg;
|
||||
@@ -62,20 +62,20 @@ public:
|
||||
|
||||
template <typename T> class GadgetMatrixSingle {
|
||||
private:
|
||||
uint32_t l_gadget;
|
||||
uint32_t level_count;
|
||||
uint32_t base_log;
|
||||
uint32_t mask;
|
||||
uint32_t halfbg;
|
||||
T offset;
|
||||
|
||||
public:
|
||||
__device__ GadgetMatrixSingle(uint32_t base_log, uint32_t l_gadget)
|
||||
: base_log(base_log), l_gadget(l_gadget) {
|
||||
__device__ GadgetMatrixSingle(uint32_t base_log, uint32_t level_count)
|
||||
: base_log(base_log), level_count(level_count) {
|
||||
uint32_t bg = 1 << base_log;
|
||||
this->halfbg = bg / 2;
|
||||
this->mask = bg - 1;
|
||||
T temp = 0;
|
||||
for (int i = 0; i < this->l_gadget; i++) {
|
||||
for (int i = 0; i < this->level_count; i++) {
|
||||
temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
|
||||
}
|
||||
this->offset = temp * this->halfbg;
|
||||
|
||||
@@ -44,13 +44,13 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
|
||||
template <typename T, typename ST, class params>
|
||||
void batch_fft_ggsw_vector(void *v_stream, double2 *dest, T *src, uint32_t r,
|
||||
uint32_t glwe_dim, uint32_t polynomial_size,
|
||||
uint32_t l_gadget) {
|
||||
uint32_t level_count) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
int shared_memory_size = sizeof(double) * polynomial_size;
|
||||
|
||||
int total_polynomials = r * (glwe_dim + 1) * (glwe_dim + 1) * l_gadget;
|
||||
int total_polynomials = r * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
|
||||
int gridSize = total_polynomials;
|
||||
int blockSize = polynomial_size / params::opt;
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@ __device__ inline Torus typecast_double_to_torus(double x) {
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
|
||||
uint32_t l_gadget) {
|
||||
T shift = sizeof(T) * 8 - l_gadget * base_log;
|
||||
uint32_t level_count) {
|
||||
T shift = sizeof(T) * 8 - level_count * base_log;
|
||||
T mask = 1ll << (shift - 1);
|
||||
T b = (x & mask) >> (shift - 1);
|
||||
T res = x >> shift;
|
||||
|
||||
@@ -6,42 +6,40 @@
|
||||
|
||||
/* Perform keyswitch on a batch of input LWE ciphertexts for 32 bits
|
||||
*
|
||||
* - lwe_out: output batch of num_samples keyswitched ciphertexts c =
|
||||
* - lwe_array_out: output batch of num_samples keyswitched ciphertexts c =
|
||||
* (a0,..an-1,b) where n is the LWE dimension
|
||||
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* mask values + 1 body value
|
||||
*
|
||||
* This function calls a wrapper to a device kernel that performs the keyswitch
|
||||
* - num_samples blocks of threads are launched
|
||||
*/
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
|
||||
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples) {
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(
|
||||
v_stream, static_cast<uint32_t *>(lwe_out),
|
||||
static_cast<uint32_t *>(lwe_in), static_cast<uint32_t *>(ksk),
|
||||
lwe_dimension_before, lwe_dimension_after, base_log, l_gadget,
|
||||
num_samples);
|
||||
v_stream, static_cast<uint32_t *>(lwe_array_out),
|
||||
static_cast<uint32_t *>(lwe_array_in), static_cast<uint32_t *>(ksk),
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
|
||||
}
|
||||
|
||||
/* Perform keyswitch on a batch of input LWE ciphertexts for 64 bits
|
||||
*
|
||||
* - lwe_out: output batch of num_samples keyswitched ciphertexts c =
|
||||
* - lwe_array_out: output batch of num_samples keyswitched ciphertexts c =
|
||||
* (a0,..an-1,b) where n is the LWE dimension
|
||||
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
|
||||
* mask values + 1 body value
|
||||
*
|
||||
* This function calls a wrapper to a device kernel that performs the keyswitch
|
||||
* - num_samples blocks of threads are launched
|
||||
*/
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
|
||||
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples) {
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(
|
||||
v_stream, static_cast<uint64_t *>(lwe_out),
|
||||
static_cast<uint64_t *>(lwe_in), static_cast<uint64_t *>(ksk),
|
||||
lwe_dimension_before, lwe_dimension_after, base_log, l_gadget,
|
||||
num_samples);
|
||||
v_stream, static_cast<uint64_t *>(lwe_array_out),
|
||||
static_cast<uint64_t *>(lwe_array_in), static_cast<uint64_t *>(ksk),
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@
|
||||
|
||||
template <typename Torus>
|
||||
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
|
||||
uint32_t lwe_dimension_after,
|
||||
uint32_t l_gadget) {
|
||||
int pos = i * l_gadget * (lwe_dimension_after + 1) +
|
||||
level * (lwe_dimension_after + 1);
|
||||
uint32_t lwe_dimension_out,
|
||||
uint32_t level_count) {
|
||||
int pos = i * level_count * (lwe_dimension_out + 1) +
|
||||
level * (lwe_dimension_out + 1);
|
||||
Torus *ptr = &ksk[pos];
|
||||
return ptr;
|
||||
}
|
||||
@@ -42,21 +42,22 @@ __device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
|
||||
*
|
||||
*/
|
||||
template <typename Torus>
|
||||
__global__ void keyswitch(Torus *lwe_out, Torus *lwe_in, Torus *ksk,
|
||||
uint32_t lwe_dimension_before,
|
||||
uint32_t lwe_dimension_after, uint32_t base_log,
|
||||
uint32_t l_gadget, int lwe_lower, int lwe_upper,
|
||||
int cutoff) {
|
||||
__global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count,
|
||||
int lwe_lower, int lwe_upper, int cutoff) {
|
||||
int tid = threadIdx.x;
|
||||
|
||||
extern __shared__ char sharedmem[];
|
||||
|
||||
Torus *local_lwe_out = (Torus *)sharedmem;
|
||||
Torus *local_lwe_array_out = (Torus *)sharedmem;
|
||||
|
||||
auto block_lwe_in = get_chunk(lwe_in, blockIdx.x, lwe_dimension_before + 1);
|
||||
auto block_lwe_out = get_chunk(lwe_out, blockIdx.x, lwe_dimension_after + 1);
|
||||
auto block_lwe_array_in =
|
||||
get_chunk(lwe_array_in, blockIdx.x, lwe_dimension_in + 1);
|
||||
auto block_lwe_array_out =
|
||||
get_chunk(lwe_array_out, blockIdx.x, lwe_dimension_out + 1);
|
||||
|
||||
auto gadget = GadgetMatrixSingle<Torus>(base_log, l_gadget);
|
||||
auto gadget = GadgetMatrixSingle<Torus>(base_log, level_count);
|
||||
|
||||
int lwe_part_per_thd;
|
||||
if (tid < cutoff) {
|
||||
@@ -68,49 +69,51 @@ __global__ void keyswitch(Torus *lwe_out, Torus *lwe_in, Torus *ksk,
|
||||
|
||||
for (int k = 0; k < lwe_part_per_thd; k++) {
|
||||
int idx = tid + k * blockDim.x;
|
||||
local_lwe_out[idx] = 0;
|
||||
local_lwe_array_out[idx] = 0;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
local_lwe_out[lwe_dimension_after] = block_lwe_in[lwe_dimension_before];
|
||||
local_lwe_array_out[lwe_dimension_out] =
|
||||
block_lwe_array_in[lwe_dimension_in];
|
||||
}
|
||||
|
||||
for (int i = 0; i < lwe_dimension_before; i++) {
|
||||
for (int i = 0; i < lwe_dimension_in; i++) {
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Torus a_i = round_to_closest_multiple(block_lwe_in[i], base_log, l_gadget);
|
||||
Torus a_i =
|
||||
round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);
|
||||
|
||||
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * l_gadget);
|
||||
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
|
||||
Torus mod_b_mask = (1ll << base_log) - 1ll;
|
||||
|
||||
for (int j = 0; j < l_gadget; j++) {
|
||||
auto ksk_block = get_ith_block(ksk, i, l_gadget - j - 1,
|
||||
lwe_dimension_after, l_gadget);
|
||||
for (int j = 0; j < level_count; j++) {
|
||||
auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
|
||||
lwe_dimension_out, level_count);
|
||||
Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
|
||||
for (int k = 0; k < lwe_part_per_thd; k++) {
|
||||
int idx = tid + k * blockDim.x;
|
||||
local_lwe_out[idx] -= (Torus)ksk_block[idx] * decomposed;
|
||||
local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 0; k < lwe_part_per_thd; k++) {
|
||||
int idx = tid + k * blockDim.x;
|
||||
block_lwe_out[idx] = local_lwe_out[idx];
|
||||
block_lwe_array_out[idx] = local_lwe_array_out[idx];
|
||||
}
|
||||
}
|
||||
|
||||
/// assume lwe_in in the gpu
|
||||
/// assume lwe_array_in in the gpu
|
||||
template <typename Torus>
|
||||
__host__ void cuda_keyswitch_lwe_ciphertext_vector(
|
||||
void *v_stream, Torus *lwe_out, Torus *lwe_in, Torus *ksk,
|
||||
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
|
||||
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples) {
|
||||
void *v_stream, Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
|
||||
constexpr int ideal_threads = 128;
|
||||
|
||||
int lwe_dim = lwe_dimension_after + 1;
|
||||
int lwe_dim = lwe_dimension_out + 1;
|
||||
int lwe_lower, lwe_upper, cutoff;
|
||||
if (lwe_dim % ideal_threads == 0) {
|
||||
lwe_lower = lwe_dim / ideal_threads;
|
||||
@@ -124,11 +127,11 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
|
||||
lwe_upper = (int)ceil((double)lwe_dim / (double)ideal_threads);
|
||||
}
|
||||
|
||||
int lwe_size_after = (lwe_dimension_after + 1) * num_samples;
|
||||
int lwe_size_after = (lwe_dimension_out + 1) * num_samples;
|
||||
|
||||
int shared_mem = sizeof(Torus) * (lwe_dimension_after + 1);
|
||||
int shared_mem = sizeof(Torus) * (lwe_dimension_out + 1);
|
||||
|
||||
cudaMemset(lwe_out, 0, sizeof(Torus) * lwe_size_after);
|
||||
cudaMemset(lwe_array_out, 0, sizeof(Torus) * lwe_size_after);
|
||||
|
||||
dim3 grid(num_samples, 1, 1);
|
||||
dim3 threads(ideal_threads, 1, 1);
|
||||
@@ -138,8 +141,8 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
keyswitch<<<grid, threads, shared_mem, *stream>>>(
|
||||
lwe_out, lwe_in, ksk, lwe_dimension_before, lwe_dimension_after, base_log,
|
||||
l_gadget, lwe_lower, lwe_upper, cutoff);
|
||||
lwe_array_out, lwe_array_in, ksk, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log, level_count, lwe_lower, lwe_upper, cutoff);
|
||||
|
||||
cudaStreamSynchronize(*stream);
|
||||
}
|
||||
|
||||
@@ -145,12 +145,12 @@ multiply_by_monomial_negacyclic_and_sub_polynomial(T *acc, T *result_acc,
|
||||
*/
|
||||
template <typename T, int elems_per_thread, int block_size>
|
||||
__device__ void round_to_closest_multiple_inplace(T *rotated_acc, int base_log,
|
||||
int l_gadget) {
|
||||
int level_count) {
|
||||
int tid = threadIdx.x;
|
||||
for (int i = 0; i < elems_per_thread; i++) {
|
||||
|
||||
T x_acc = rotated_acc[tid];
|
||||
T shift = sizeof(T) * 8 - l_gadget * base_log;
|
||||
T shift = sizeof(T) * 8 - level_count * base_log;
|
||||
T mask = 1ll << (shift - 1);
|
||||
T b_acc = (x_acc & mask) >> (shift - 1);
|
||||
T res_acc = x_acc >> shift;
|
||||
@@ -191,13 +191,13 @@ __device__ void add_to_torus(double2 *m_values, Torus *result) {
|
||||
}
|
||||
|
||||
template <typename Torus, class params>
|
||||
__device__ void sample_extract_body(Torus *lwe_out, Torus *accumulator) {
|
||||
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *accumulator) {
|
||||
// Set first coefficient of the accumulator as the body of the LWE sample
|
||||
lwe_out[params::degree] = accumulator[0];
|
||||
lwe_array_out[params::degree] = accumulator[0];
|
||||
}
|
||||
|
||||
template <typename Torus, class params>
|
||||
__device__ void sample_extract_mask(Torus *lwe_out, Torus *accumulator) {
|
||||
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *accumulator) {
|
||||
// Set ACC = -ACC
|
||||
// accumulator.negate_inplace();
|
||||
|
||||
@@ -257,12 +257,12 @@ __device__ void sample_extract_mask(Torus *lwe_out, Torus *accumulator) {
|
||||
synchronize_threads_in_block();
|
||||
|
||||
// Copy to the mask of the LWE sample
|
||||
// accumulator.copy_into(lwe_out);
|
||||
// accumulator.copy_into(lwe_array_out);
|
||||
|
||||
tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
lwe_out[tid] = accumulator[tid];
|
||||
lwe_array_out[tid] = accumulator[tid];
|
||||
tid = tid + params::degree / params::opt;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,13 +369,13 @@ public:
|
||||
}
|
||||
|
||||
__device__ void round_to_closest_multiple_inplace(uint32_t base_log,
|
||||
uint32_t l_gadget) {
|
||||
uint32_t level_count) {
|
||||
int tid = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < params::opt; i++) {
|
||||
|
||||
T x = coefficients[tid];
|
||||
T shift = sizeof(T) * 8 - l_gadget * base_log;
|
||||
T shift = sizeof(T) * 8 - level_count * base_log;
|
||||
T mask = 1ll << (shift - 1);
|
||||
T b = (x & mask) >> (shift - 1);
|
||||
T res = x >> shift;
|
||||
|
||||
Reference in New Issue
Block a user