chore(cuda): rename some variables to match concrete-core notations

- rename l_gadget into level_count and stop calling the low latency PBS with N too large
- rename trlwe and trgsw into glwe and ggsw
- rename lwe_mask_size into lwe_dimension
- rename lwe_in into lwe_array_in
- rename lwe_out into lwe_array_out
- rename decomp_level into level
- rename lwe_dimension_before/after into lwe_dimension_in/out
Author: Agnes Leroy
Date: 2022-10-14 11:51:29 +02:00
Committed by: Agnès Leroy
Parent: c22aa3e4e9
Commit: 4445fcc7f1
16 changed files with 509 additions and 518 deletions

View File

@@ -10,91 +10,90 @@ void cuda_initialize_twiddles(uint32_t polynomial_size, uint32_t gpu_index);
void cuda_convert_lwe_bootstrap_key_32(void *dest, void *src, void *v_stream,
uint32_t gpu_index,
uint32_t input_lwe_dim,
uint32_t glwe_dim, uint32_t l_gadget,
uint32_t glwe_dim, uint32_t level_count,
uint32_t polynomial_size);
void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream,
uint32_t gpu_index,
uint32_t input_lwe_dim,
uint32_t glwe_dim, uint32_t l_gadget,
uint32_t glwe_dim, uint32_t level_count,
uint32_t polynomial_size);
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory);
void *v_stream, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory);
void *v_stream, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory);
void *v_stream, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
void *v_stream, void *lwe_out, void *test_vector, void *test_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_test_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory);
void *v_stream, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
void cuda_cmux_tree_32(void *v_stream, void *glwe_out, void *ggsw_in,
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
void *lut_vector, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t r,
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory);
void cuda_cmux_tree_64(void *v_stream, void *glwe_out, void *ggsw_in,
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
void *lut_vector, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t r,
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory);
void cuda_extract_bits_32(void *v_stream, void *list_lwe_out, void *lwe_in,
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk,
void *fourier_bsk, uint32_t number_of_bits,
uint32_t delta_log, uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
uint32_t number_of_samples);
void cuda_extract_bits_32(
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
uint32_t number_of_samples);
void cuda_extract_bits_64(void *v_stream, void *list_lwe_out, void *lwe_in,
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk,
void *fourier_bsk, uint32_t number_of_bits,
uint32_t delta_log, uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
uint32_t number_of_samples);
void cuda_extract_bits_64(
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
uint32_t number_of_samples);
};
#ifdef __CUDACC__
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
int glwe_dimension, uint32_t l_gadget);
int glwe_dimension,
uint32_t level_count);
template <typename T>
__device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
int glwe_dimension, uint32_t l_gadget);
int glwe_dimension, uint32_t level_count);
template <typename T>
__device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
int glwe_dimension, uint32_t l_gadget);
int glwe_dimension, uint32_t level_count);
#endif
#endif // CUDA_BOOTSTRAP_H
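
For illustration, a minimal call sketch of the renamed 64-bit amortized PBS entry point (an editorial addition, not part of this commit: the prototype is copied from the header above, while every parameter value and pointer name is a hypothetical placeholder for device allocations set up elsewhere):

#include <cstdint>

extern "C" void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
    void *v_stream, void *lwe_array_out, void *test_vector,
    void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
    uint32_t base_log, uint32_t level_count, uint32_t num_samples,
    uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);

// Sketch: bootstrap one 64-bit LWE ciphertext with the renamed arguments.
void run_amortized_pbs(void *stream, void *d_lwe_array_out,
                       void *d_test_vector, void *d_test_vector_indexes,
                       void *d_lwe_array_in, void *d_bsk) {
  uint32_t lwe_dimension = 600;    // n (hypothetical)
  uint32_t glwe_dimension = 1;     // k; the implementation hard-sets this to 1
  uint32_t polynomial_size = 1024; // N, one of the supported sizes
  uint32_t base_log = 8;           // log2(B), must be <= 16 per the assert
  uint32_t level_count = 4;        // decomposition levels (formerly l_gadget)
  cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
      stream, d_lwe_array_out, d_test_vector, d_test_vector_indexes,
      d_lwe_array_in, d_bsk, lwe_dimension, glwe_dimension, polynomial_size,
      base_log, level_count, /*num_samples=*/1, /*num_test_vectors=*/1,
      /*lwe_idx=*/0, /*max_shared_memory=*/49152);
}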

View File

@@ -6,14 +6,14 @@
extern "C" {
void cuda_keyswitch_lwe_ciphertext_vector_32(
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples);
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples);
void cuda_keyswitch_lwe_ciphertext_vector_64(
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples);
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples);
}
#endif // CNCRT_KS_H_
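
Similarly, a hedged sketch of the renamed keyswitch entry point (editorial addition; the prototype comes from the header above, all values are hypothetical, and lwe_dimension_in/out replace the old before/after suffixes):

#include <cstdint>

extern "C" void cuda_keyswitch_lwe_ciphertext_vector_64(
    void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples);

// Sketch: keyswitch one ciphertext from dimension 1024 down to 600.
void run_keyswitch(void *stream, void *d_lwe_array_out, void *d_lwe_array_in,
                   void *d_ksk) {
  cuda_keyswitch_lwe_ciphertext_vector_64(
      stream, d_lwe_array_out, d_lwe_array_in, d_ksk,
      /*lwe_dimension_in=*/1024, /*lwe_dimension_out=*/600,
      /*base_log=*/4, /*level_count=*/7, /*num_samples=*/1);
}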

View File

@@ -2,7 +2,7 @@
/* Perform bootstrapping on a batch of input LWE ciphertexts
*
* - lwe_out: output batch of num_samples bootstrapped ciphertexts c =
* - lwe_array_out: output batch of num_samples bootstrapped ciphertexts c =
* (a0,..an-1,b) where n is the LWE dimension
* - lut_vector: should hold as many test vectors of size polynomial_size
* as there are input ciphertexts, but actually holds
@@ -10,7 +10,7 @@
* - lut_vector_indexes: stores the index corresponding to
* which test vector to use for each sample in
* lut_vector
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
* mask values + 1 body value
* - bootstrapping_key: RGSW encryption of the LWE secret key sk1
* under secret key sk2
@@ -30,7 +30,7 @@
* - polynomial_size: size of the test polynomial (test vector) and size of the
* GLWE polynomial (~1024)
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
* - level_count: number of decomposition levels in the gadget matrix (~4)
* - num_samples: number of encrypted input messages
* - num_lut_vectors: parameter to set the actual number of test vectors to be
* used
@@ -44,7 +44,7 @@
* to handle one or more polynomial coefficients at each stage:
* - perform the blind rotation
* - round the result
* - decompose into l_gadget levels, then for each level:
* - decompose into level_count levels, then for each level:
* - switch to the FFT domain
* - multiply with the bootstrapping key
* - come back to the coefficients representation
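
As an editorial gloss on the steps listed above (not part of the diff), one blind-rotation iteration over mask element $a_i$ can be written, per GLWE component (mask and body) and in informal notation, as

$$\mathrm{ACC} \leftarrow \mathrm{ACC} + \sum_{j=1}^{\ell} \mathrm{FFT}^{-1}\Big(\mathrm{FFT}\big(\mathcal{G}^{-1}\big(\mathrm{ACC}\cdot(X^{\tilde a_i}-1)\big)_j\big) \odot \widehat{\mathrm{BSK}}_{i,j}\Big),$$

where $\ell$ is level_count, $\mathcal{G}^{-1}$ is the gadget decomposition with base $B = 2^{\texttt{base\_log}}$, and $\tilde a_i$ is the mask element rescaled to $[0, 2N)$.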
@@ -58,11 +58,11 @@
*/
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory) {
void *v_stream, void *lwe_array_out, void *lut_vector,
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
assert(
("Error (GPU amortized PBS): base log should be <= 16", base_log <= 16));
@@ -77,38 +77,38 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
switch (polynomial_size) {
case 512:
host_bootstrap_amortized<uint32_t, Degree<512>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 1024:
host_bootstrap_amortized<uint32_t, Degree<1024>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 2048:
host_bootstrap_amortized<uint32_t, Degree<2048>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 4096:
host_bootstrap_amortized<uint32_t, Degree<4096>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 8192:
host_bootstrap_amortized<uint32_t, Degree<8192>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
default:
break;
@@ -116,11 +116,11 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
}
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory) {
void *v_stream, void *lwe_array_out, void *lut_vector,
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
assert(
("Error (GPU amortized PBS): base log should be <= 16", base_log <= 16));
@@ -135,38 +135,38 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
switch (polynomial_size) {
case 512:
host_bootstrap_amortized<uint64_t, Degree<512>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 1024:
host_bootstrap_amortized<uint64_t, Degree<1024>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 2048:
host_bootstrap_amortized<uint64_t, Degree<2048>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 4096:
host_bootstrap_amortized<uint64_t, Degree<4096>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
case 8192:
host_bootstrap_amortized<uint64_t, Degree<8192>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
break;
default:
break;

View File

@@ -29,35 +29,36 @@ template <typename Torus, class params, sharedMemDegree SMD>
* Kernel launched by host_bootstrap_amortized
*
* Uses shared memory to increase performance
* - lwe_out: output batch of num_samples bootstrapped ciphertexts c =
* - lwe_array_out: output batch of num_samples bootstrapped ciphertexts c =
* (a0,..an-1,b) where n is the LWE dimension
* - lut_vector: should hold as many test vectors of size polynomial_size
* as there are input ciphertexts, but actually holds
* num_lut_vectors vectors to reduce memory usage
* - lut_vector_indexes: stores the index corresponding to which test vector
* to use for each sample in lut_vector
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n mask
* values + 1 body value
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
* mask values + 1 body value
* - bootstrapping_key: RGSW encryption of the LWE secret key sk1 under secret
* key sk2
* - device_mem: pointer to the device's global memory in case we use it (SMD
* == NOSM or PARTIALSM)
* - lwe_mask_size: size of the Torus vector used to encrypt the input
* - lwe_dimension: size of the Torus vector used to encrypt the input
* LWE ciphertexts - referred to as n above (~ 600)
* - polynomial_size: size of the test polynomial (test vector) and size of the
* GLWE polynomial (~1024)
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
* - level_count: number of decomposition levels in the gadget matrix (~4)
* - gpu_num: index of the current GPU (useful for multi-GPU computations)
* - lwe_idx: equal to the number of samples per gpu x gpu_num
* - device_memory_size_per_sample: amount of global memory to allocate if SMD
* is not FULLSM
*/
__global__ void device_bootstrap_amortized(
Torus *lwe_out, Torus *lut_vector, uint32_t *lut_vector_indexes,
Torus *lwe_in, double2 *bootstrapping_key, char *device_mem,
uint32_t lwe_mask_size, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t lwe_idx, size_t device_memory_size_per_sample) {
Torus *lwe_array_out, Torus *lut_vector, uint32_t *lut_vector_indexes,
Torus *lwe_array_in, double2 *bootstrapping_key, char *device_mem,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t lwe_idx,
size_t device_memory_size_per_sample) {
// We use shared memory for the polynomials that are used often during the
// bootstrap, since shared memory is kept in L1 cache and accessing it is
// much faster than global memory
@@ -69,7 +70,7 @@ __global__ void device_bootstrap_amortized(
else
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_sample];
// For GPU bootstrapping the RLWE dimension is hard-set to 1: there is only
// For GPU bootstrapping the GLWE dimension is hard-set to 1: there is only
// one mask polynomial and 1 body to handle. Also, since the decomposed
// polynomials take coefficients between -B/2 and B/2, they can be represented
// with only 16 bits, assuming the base log does not exceed 16
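// Editorial example: with base_log = 8 the decomposed digits lie in
// [-128, 128), which comfortably fits the 16-bit representation mentioned
// above.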
@@ -93,16 +94,16 @@ __global__ void device_bootstrap_amortized(
accumulator_fft =
(double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
auto block_lwe_in = &lwe_in[blockIdx.x * (lwe_mask_size + 1)];
auto block_lwe_array_in = &lwe_array_in[blockIdx.x * (lwe_dimension + 1)];
Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[lwe_idx + blockIdx.x] * params::degree *
2];
GadgetMatrix<Torus, params> gadget(base_log, l_gadget);
GadgetMatrix<Torus, params> gadget(base_log, level_count);
// Put "b", the body, in [0, 2N[
Torus b_hat = rescale_torus_element(
block_lwe_in[lwe_mask_size],
block_lwe_array_in[lwe_dimension],
2 * params::degree); // 2 * params::log2_degree + 1);
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
@@ -115,14 +116,14 @@ __global__ void device_bootstrap_amortized(
// Loop over all the mask elements of the sample to accumulate
// (X^a_i-1) multiplication, decomposition of the resulting polynomial
// into l_gadget polynomials, and performing polynomial multiplication
// into level_count polynomials, and performing polynomial multiplication
// via an FFT with the RGSW encrypted secret key
for (int iteration = 0; iteration < lwe_mask_size; iteration++) {
for (int iteration = 0; iteration < lwe_dimension; iteration++) {
synchronize_threads_in_block();
// Put "a" in [0, 2N[ instead of Zq
Torus a_hat = rescale_torus_element(
block_lwe_in[iteration],
block_lwe_array_in[iteration],
2 * params::degree); // 2 * params::log2_degree + 1);
// Perform ACC * (X^â - 1)
@@ -140,11 +141,11 @@ __global__ void device_bootstrap_amortized(
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_mask_rotated, base_log, l_gadget);
accumulator_mask_rotated, base_log, level_count);
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_body_rotated, base_log, l_gadget);
accumulator_body_rotated, base_log, level_count);
// Initialize the polynomial multiplication via FFT arrays
// The polynomial multiplications happens at the block level
// and each thread handles two or more coefficients
@@ -160,13 +161,13 @@ __global__ void device_bootstrap_amortized(
// Now that the rotation is done, decompose the resulting polynomial
// coefficients so as to multiply each decomposed level with the
// corresponding part of the bootstrapping key
for (int decomp_level = 0; decomp_level < l_gadget; decomp_level++) {
for (int level = 0; level < level_count; level++) {
gadget.decompose_one_level(accumulator_mask_decomposed,
accumulator_mask_rotated, decomp_level);
accumulator_mask_rotated, level);
gadget.decompose_one_level(accumulator_body_decomposed,
accumulator_body_rotated, decomp_level);
accumulator_body_rotated, level);
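// Editorial note: each decompose_one_level call extracts the signed digit
// x_level in [-B/2, B/2) of the rounded coefficient
// x ~ sum_{j=1}^{level_count} x_j * q / B^j, with B = 2^base_log and q the
// Torus modulus.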
synchronize_threads_in_block();
@@ -187,11 +188,11 @@ __global__ void device_bootstrap_amortized(
// Get the bootstrapping key piece necessary for the multiplication
// It is already in the Fourier domain
auto bsk_mask_slice = PolynomialFourier<double2, params>(
get_ith_mask_kth_block(bootstrapping_key, iteration, 0, decomp_level,
polynomial_size, 1, l_gadget));
get_ith_mask_kth_block(bootstrapping_key, iteration, 0, level,
polynomial_size, 1, level_count));
auto bsk_body_slice = PolynomialFourier<double2, params>(
get_ith_body_kth_block(bootstrapping_key, iteration, 0, decomp_level,
polynomial_size, 1, l_gadget));
get_ith_body_kth_block(bootstrapping_key, iteration, 0, level,
polynomial_size, 1, level_count));
synchronize_threads_in_block();
@@ -216,11 +217,11 @@ __global__ void device_bootstrap_amortized(
correction_direct_fft_inplace<params>(accumulator_fft);
auto bsk_mask_slice_2 = PolynomialFourier<double2, params>(
get_ith_mask_kth_block(bootstrapping_key, iteration, 1, decomp_level,
polynomial_size, 1, l_gadget));
get_ith_mask_kth_block(bootstrapping_key, iteration, 1, level,
polynomial_size, 1, level_count));
auto bsk_body_slice_2 = PolynomialFourier<double2, params>(
get_ith_body_kth_block(bootstrapping_key, iteration, 1, decomp_level,
polynomial_size, 1, l_gadget));
get_ith_body_kth_block(bootstrapping_key, iteration, 1, level,
polynomial_size, 1, level_count));
synchronize_threads_in_block();
@@ -283,23 +284,24 @@ __global__ void device_bootstrap_amortized(
}
}
auto block_lwe_out = &lwe_out[blockIdx.x * (polynomial_size + 1)];
auto block_lwe_array_out = &lwe_array_out[blockIdx.x * (polynomial_size + 1)];
// The blind rotation for this block is over
// Now we can perform the sample extraction: for the body it's just
// the resulting constant coefficient of the accumulator
// For the mask it's more complicated
sample_extract_mask<Torus, params>(block_lwe_out, accumulator_mask);
sample_extract_body<Torus, params>(block_lwe_out, accumulator_body);
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator_mask);
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator_body);
}
template <typename Torus, class params>
__host__ void host_bootstrap_amortized(
void *v_stream, Torus *lwe_out, Torus *lut_vector,
uint32_t *lut_vector_indexes, Torus *lwe_in, double2 *bootstrapping_key,
uint32_t input_lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t input_lwe_ciphertext_count,
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
void *v_stream, Torus *lwe_array_out, Torus *lut_vector,
uint32_t *lut_vector_indexes, Torus *lwe_array_in,
double2 *bootstrapping_key, uint32_t input_lwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, uint32_t num_lut_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory) {
int SM_FULL = sizeof(Torus) * polynomial_size + // accumulator mask
sizeof(Torus) * polynomial_size + // accumulator body
@@ -338,9 +340,9 @@ __host__ void host_bootstrap_amortized(
checkCudaErrors(
cudaMalloc((void **)&d_mem, DM_FULL * input_lwe_ciphertext_count));
device_bootstrap_amortized<Torus, params, NOSM><<<grid, thds, 0, *stream>>>(
lwe_out, lut_vector, lut_vector_indexes, lwe_in, bootstrapping_key,
d_mem, input_lwe_dimension, polynomial_size, base_log, l_gadget,
lwe_idx, DM_FULL);
lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
base_log, level_count, lwe_idx, DM_FULL);
} else if (max_shared_memory < SM_FULL) {
cudaFuncSetAttribute(device_bootstrap_amortized<Torus, params, PARTIALSM>,
cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART);
@@ -350,9 +352,9 @@ __host__ void host_bootstrap_amortized(
cudaMalloc((void **)&d_mem, DM_PART * input_lwe_ciphertext_count));
device_bootstrap_amortized<Torus, params, PARTIALSM>
<<<grid, thds, SM_PART, *stream>>>(
lwe_out, lut_vector, lut_vector_indexes, lwe_in, bootstrapping_key,
d_mem, input_lwe_dimension, polynomial_size, base_log, l_gadget,
lwe_idx, DM_PART);
lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
base_log, level_count, lwe_idx, DM_PART);
} else {
// For devices with compute capability 7.x a single thread block can
// address the full capacity of shared memory. Shared memory on the
@@ -369,12 +371,12 @@ __host__ void host_bootstrap_amortized(
device_bootstrap_amortized<Torus, params, FULLSM>
<<<grid, thds, SM_FULL, *stream>>>(
lwe_out, lut_vector, lut_vector_indexes, lwe_in, bootstrapping_key,
d_mem, input_lwe_dimension, polynomial_size, base_log, l_gadget,
lwe_idx, 0);
lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
base_log, level_count, lwe_idx, 0);
}
// Synchronize the streams before copying the result to lwe_out at the right
// place
// Synchronize the streams before copying the result to lwe_array_out at the
// right place
cudaStreamSynchronize(*stream);
cudaFree(d_mem);
}

View File

@@ -2,7 +2,7 @@
/* Perform bootstrapping on a batch of input LWE ciphertexts
*
* - lwe_out: output batch of num_samples bootstrapped ciphertexts c =
* - lwe_array_out: output batch of num_samples bootstrapped ciphertexts c =
* (a0,..an-1,b) where n is the LWE dimension
* - lut_vector: should hold as many test vectors of size polynomial_size
* as there are input ciphertexts, but actually holds
@@ -10,7 +10,7 @@
* - lut_vector_indexes: stores the index corresponding to
* which test vector to use for each sample in
* lut_vector
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
* mask values + 1 body value
* - bootstrapping_key: RGSW encryption of the LWE secret key sk1
* under secret key sk2
@@ -30,7 +30,7 @@
* - polynomial_size: size of the test polynomial (test vector) and size of the
* GLWE polynomial (~1024)
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
* - level_count: number of decomposition levels in the gadget matrix (~4)
* - num_samples: number of encrypted input messages
* - num_lut_vectors: parameter to set the actual number of test vectors to be
* used
@@ -44,7 +44,7 @@
* to handle one or more polynomial coefficients at each stage:
* - perform the blind rotation
* - round the result
* - decompose into l_gadget levels, then for each level:
* - decompose into level_count levels, then for each level:
* - switch to the FFT domain
* - multiply with the bootstrapping key
* - come back to the coefficients representation
@@ -57,11 +57,11 @@
* values for the FFT
*/
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory) {
void *v_stream, void *lwe_array_out, void *lut_vector,
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
assert(("Error (GPU low latency PBS): base log should be <= 16",
base_log <= 16));
@@ -79,44 +79,30 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
assert(("Error (GPU low latency PBS): the number of input LWEs must be lower "
"or equal to the "
"number of streaming multiprocessors on the device divided by 8 * "
"l_gadget",
num_samples <= number_of_sm / 4. / 2. / l_gadget));
"level_count",
num_samples <= number_of_sm / 4. / 2. / level_count));
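// Editorial example: on a hypothetical device with 108 streaming
// multiprocessors and level_count = 2, the bound is 108 / 4. / 2. / 2 = 6.75,
// i.e. at most 6 input LWE ciphertexts per launch.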
switch (polynomial_size) {
case 512:
host_bootstrap_low_latency<uint32_t, Degree<512>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
level_count, num_samples, num_lut_vectors);
break;
case 1024:
host_bootstrap_low_latency<uint32_t, Degree<1024>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
level_count, num_samples, num_lut_vectors);
break;
case 2048:
host_bootstrap_low_latency<uint32_t, Degree<2048>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
break;
case 4096:
host_bootstrap_low_latency<uint32_t, Degree<4096>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
break;
case 8192:
host_bootstrap_low_latency<uint32_t, Degree<8192>>(
v_stream, (uint32_t *)lwe_out, (uint32_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
level_count, num_samples, num_lut_vectors);
break;
default:
break;
@@ -124,11 +110,11 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
}
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
void *v_stream, void *lwe_out, void *lut_vector, void *lut_vector_indexes,
void *lwe_in, void *bootstrapping_key, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors,
uint32_t lwe_idx, uint32_t max_shared_memory) {
void *v_stream, void *lwe_array_out, void *lut_vector,
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) {
assert(("Error (GPU low latency PBS): base log should be <= 16",
base_log <= 16));
@@ -146,44 +132,30 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
assert(("Error (GPU low latency PBS): the number of input LWEs must be lower "
"or equal to the "
"number of streaming multiprocessors on the device divided by 8 * "
"l_gadget",
num_samples <= number_of_sm / 4. / 2. / l_gadget));
"level_count",
num_samples <= number_of_sm / 4. / 2. / level_count));
switch (polynomial_size) {
case 512:
host_bootstrap_low_latency<uint64_t, Degree<512>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
level_count, num_samples, num_lut_vectors);
break;
case 1024:
host_bootstrap_low_latency<uint64_t, Degree<1024>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
level_count, num_samples, num_lut_vectors);
break;
case 2048:
host_bootstrap_low_latency<uint64_t, Degree<2048>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
break;
case 4096:
host_bootstrap_low_latency<uint64_t, Degree<4096>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
break;
case 8192:
host_bootstrap_low_latency<uint64_t, Degree<8192>>(
v_stream, (uint64_t *)lwe_out, (uint64_t *)lut_vector,
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_in,
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
l_gadget, num_samples, num_lut_vectors);
level_count, num_samples, num_lut_vectors);
break;
default:
break;

View File

@@ -29,13 +29,13 @@ namespace cg = cooperative_groups;
template <typename Torus, class params>
__device__ void
mul_trgsw_trlwe(Torus *accumulator, double2 *fft, int16_t *trlwe_decomposed,
double2 *mask_join_buffer, double2 *body_join_buffer,
double2 *bootstrapping_key, int polynomial_size, int l_gadget,
int iteration, grid_group &grid) {
mul_ggsw_glwe(Torus *accumulator, double2 *fft, int16_t *glwe_decomposed,
double2 *mask_join_buffer, double2 *body_join_buffer,
double2 *bootstrapping_key, int polynomial_size, int level_count,
int iteration, grid_group &grid) {
// Put the decomposed TRLWE sample in the Fourier domain
real_to_complex_compressed<int16_t, params>(trlwe_decomposed, fft);
// Put the decomposed GLWE sample in the Fourier domain
real_to_complex_compressed<int16_t, params>(glwe_decomposed, fft);
synchronize_threads_in_block();
// Switch to the FFT space
@@ -53,12 +53,12 @@ mul_trgsw_trlwe(Torus *accumulator, double2 *fft, int16_t *trlwe_decomposed,
auto bsk_mask_slice = PolynomialFourier<double2, params>(
get_ith_mask_kth_block(bootstrapping_key, iteration, blockIdx.y,
blockIdx.x, polynomial_size, 1, l_gadget));
blockIdx.x, polynomial_size, 1, level_count));
auto bsk_body_slice = PolynomialFourier<double2, params>(
get_ith_body_kth_block(bootstrapping_key, iteration, blockIdx.y,
blockIdx.x, polynomial_size, 1, l_gadget));
blockIdx.x, polynomial_size, 1, level_count));
// Perform the matrix multiplication between the RGSW and the TRLWE,
// Perform the matrix multiplication between the GGSW and the GLWE,
// each block operating on a single level for mask and body
auto first_processed_bsk =
@@ -120,7 +120,7 @@ mul_trgsw_trlwe(Torus *accumulator, double2 *fft, int16_t *trlwe_decomposed,
correction_inverse_fft_inplace<params>(fft);
synchronize_threads_in_block();
// Perform the inverse FFT on the result of the RGSWxTRLWE and add to the
// Perform the inverse FFT on the result of the GGSW x GLWE and add to the
// accumulator
NSMFFT_inverse<HalfDegree<params>>(fft);
synchronize_threads_in_block();
@@ -134,17 +134,18 @@ template <typename Torus, class params>
/*
* Kernel launched by the low latency version of the
* bootstrapping, that uses cooperative groups
* lwe_out - vector of output LWEs, with length (polynomial_size + 1) * num_samples
* lut_vector - vector of look up tables with length polynomial_size * num_samples
* lut_vector_indexes - mapping between lwe_in and lut_vector
* lwe_in - vector of LWE inputs with length (lwe_mask_size + 1) * num_samples
* lwe_array_out - vector of output LWEs, with length (polynomial_size + 1) * num_samples
* lut_vector - vector of look up tables with length polynomial_size * num_samples
* lut_vector_indexes - mapping between lwe_array_in and lut_vector
* lwe_array_in - vector of LWE inputs with length (lwe_dimension + 1) * num_samples
*
*/
__global__ void device_bootstrap_low_latency(
Torus *lwe_out, Torus *lut_vector, Torus *lwe_in,
Torus *lwe_array_out, Torus *lut_vector, Torus *lwe_array_in,
double2 *bootstrapping_key, double2 *mask_join_buffer,
double2 *body_join_buffer, uint32_t lwe_mask_size, uint32_t polynomial_size,
uint32_t base_log, uint32_t l_gadget) {
double2 *body_join_buffer, uint32_t lwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count) {
grid_group grid = this_grid();
@@ -167,23 +168,23 @@ __global__ void device_bootstrap_low_latency(
// The third dimension of the block is used to determine on which ciphertext
// this block is operating, in the case of batch bootstraps
auto block_lwe_in = &lwe_in[blockIdx.z * (lwe_mask_size + 1)];
auto block_lwe_array_in = &lwe_array_in[blockIdx.z * (lwe_dimension + 1)];
auto block_lut_vector = &lut_vector[blockIdx.z * params::degree * 2];
auto block_mask_join_buffer =
&mask_join_buffer[blockIdx.z * l_gadget * params::degree / 2];
&mask_join_buffer[blockIdx.z * level_count * params::degree / 2];
auto block_body_join_buffer =
&body_join_buffer[blockIdx.z * l_gadget * params::degree / 2];
&body_join_buffer[blockIdx.z * level_count * params::degree / 2];
// Since the space in L1 cache is small, we use the same memory location for
// the rotated accumulator and the fft accumulator, since we know that the
// rotated array is not in use anymore by the time we perform the fft
GadgetMatrix<Torus, params> gadget(base_log, l_gadget);
GadgetMatrix<Torus, params> gadget(base_log, level_count);
// Put "b" in [0, 2N[
Torus b_hat =
rescale_torus_element(block_lwe_in[lwe_mask_size], 2 * params::degree);
Torus b_hat = rescale_torus_element(block_lwe_array_in[lwe_dimension],
2 * params::degree);
if (blockIdx.y == 0) {
divide_by_monomial_negacyclic_inplace<Torus, params::opt,
@@ -195,12 +196,12 @@ __global__ void device_bootstrap_low_latency(
accumulator, &block_lut_vector[params::degree], b_hat, false);
}
for (int i = 0; i < lwe_mask_size; i++) {
for (int i = 0; i < lwe_dimension; i++) {
synchronize_threads_in_block();
// Put "a" in [0, 2N[
Torus a_hat = rescale_torus_element(
block_lwe_in[i],
block_lwe_array_in[i],
2 * params::degree); // 2 * params::log2_degree + 1);
// Perform ACC * (X^â - 1)
@@ -212,7 +213,7 @@ __global__ void device_bootstrap_low_latency(
// bootstrapped ciphertext
round_to_closest_multiple_inplace<Torus, params::opt,
params::degree / params::opt>(
accumulator_rotated, base_log, l_gadget);
accumulator_rotated, base_log, level_count);
// Decompose the accumulator. Each block gets one level of the
// decomposition, for the mask and the body (so block 0 will have the
@@ -224,22 +225,22 @@ __global__ void device_bootstrap_low_latency(
// accumulator_rotated, so we need to synchronize here to make sure they
// don't modify the same memory space at the same time
synchronize_threads_in_block();
// Perform G^-1(ACC) * RGSW -> TRLWE
mul_trgsw_trlwe<Torus, params>(
accumulator, accumulator_fft, accumulator_decomposed,
block_mask_join_buffer, block_body_join_buffer, bootstrapping_key,
polynomial_size, l_gadget, i, grid);
// Perform G^-1(ACC) * GGSW -> GLWE
mul_ggsw_glwe<Torus, params>(accumulator, accumulator_fft,
accumulator_decomposed, block_mask_join_buffer,
block_body_join_buffer, bootstrapping_key,
polynomial_size, level_count, i, grid);
}
auto block_lwe_out = &lwe_out[blockIdx.z * (polynomial_size + 1)];
auto block_lwe_array_out = &lwe_array_out[blockIdx.z * (polynomial_size + 1)];
if (blockIdx.x == 0 && blockIdx.y == 0) {
// Perform a sample extract. At this point, all blocks have the result, but
// we do the computation at block 0 to avoid waiting for extra blocks, in
// case they're not synchronized
sample_extract_mask<Torus, params>(block_lwe_out, accumulator);
sample_extract_mask<Torus, params>(block_lwe_array_out, accumulator);
} else if (blockIdx.x == 0 && blockIdx.y == 1) {
sample_extract_body<Torus, params>(block_lwe_out, accumulator);
sample_extract_body<Torus, params>(block_lwe_array_out, accumulator);
}
}
@@ -248,16 +249,18 @@ __global__ void device_bootstrap_low_latency(
* of bootstrapping
*/
template <typename Torus, class params>
__host__ void host_bootstrap_low_latency(
void *v_stream, Torus *lwe_out, Torus *lut_vector,
uint32_t *lut_vector_indexes, Torus *lwe_in, double2 *bootstrapping_key,
uint32_t lwe_mask_size, uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t num_samples, uint32_t num_lut_vectors) {
__host__ void
host_bootstrap_low_latency(void *v_stream, Torus *lwe_array_out,
Torus *lut_vector, uint32_t *lut_vector_indexes,
Torus *lwe_array_in, double2 *bootstrapping_key,
uint32_t lwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t num_lut_vectors) {
auto stream = static_cast<cudaStream_t *>(v_stream);
int buffer_size_per_gpu =
l_gadget * num_samples * polynomial_size / 2 * sizeof(double2);
level_count * num_samples * polynomial_size / 2 * sizeof(double2);
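// Editorial example: with level_count = 2, num_samples = 1 and
// polynomial_size = 1024 (hypothetical values), each join buffer is
// 2 * 1 * 512 * sizeof(double2) = 16 KiB.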
double2 *mask_buffer_fft;
double2 *body_buffer_fft;
checkCudaErrors(cudaMalloc((void **)&mask_buffer_fft, buffer_size_per_gpu));
@@ -268,19 +271,19 @@ __host__ void host_bootstrap_low_latency(
sizeof(double2) * polynomial_size / 2; // accumulator fft
int thds = polynomial_size / params::opt;
dim3 grid(l_gadget, 2, num_samples);
dim3 grid(level_count, 2, num_samples);
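// Editorial note: blockIdx.x indexes the decomposition level, blockIdx.y
// selects mask (0) or body (1), and blockIdx.z the input ciphertext, matching
// the indexing used in mul_ggsw_glwe and device_bootstrap_low_latency above.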
void *kernel_args[10];
kernel_args[0] = &lwe_out;
kernel_args[0] = &lwe_array_out;
kernel_args[1] = &lut_vector;
kernel_args[2] = &lwe_in;
kernel_args[2] = &lwe_array_in;
kernel_args[3] = &bootstrapping_key;
kernel_args[4] = &mask_buffer_fft;
kernel_args[5] = &body_buffer_fft;
kernel_args[6] = &lwe_mask_size;
kernel_args[6] = &lwe_dimension;
kernel_args[7] = &polynomial_size;
kernel_args[8] = &base_log;
kernel_args[9] = &l_gadget;
kernel_args[9] = &level_count;
checkCudaErrors(cudaFuncSetAttribute(
device_bootstrap_low_latency<Torus, params>,
@@ -292,8 +295,8 @@ __host__ void host_bootstrap_low_latency(
(void *)device_bootstrap_low_latency<Torus, params>, grid, thds,
(void **)kernel_args, bytes_needed, *stream));
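// Editorial note: this is a cooperative launch (presumably via
// cudaLaunchCooperativeKernel, whose name is elided by the hunk above); it is
// required because the kernel synchronizes across the whole grid via
// this_grid().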
// Synchronize the streams before copying the result to lwe_out at the right
// place
// Synchronize the streams before copying the result to lwe_array_out at the
// right place
cudaStreamSynchronize(*stream);
cudaFree(mask_buffer_fft);
cudaFree(body_buffer_fft);

View File

@@ -1,9 +1,9 @@
#include "bootstrap_wop.cuh"
void cuda_cmux_tree_32(void *v_stream, void *glwe_out, void *ggsw_in,
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
void *lut_vector, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t r,
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory) {
assert(("Error (GPU Cmux tree): base log should be <= 16", base_log <= 16));
@@ -22,43 +22,43 @@ void cuda_cmux_tree_32(void *v_stream, void *glwe_out, void *ggsw_in,
switch (polynomial_size) {
case 512:
host_cmux_tree<uint32_t, int32_t, Degree<512>>(
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 1024:
host_cmux_tree<uint32_t, int32_t, Degree<1024>>(
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 2048:
host_cmux_tree<uint32_t, int32_t, Degree<2048>>(
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 4096:
host_cmux_tree<uint32_t, int32_t, Degree<4096>>(
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 8192:
host_cmux_tree<uint32_t, int32_t, Degree<8192>>(
v_stream, (uint32_t *)glwe_out, (uint32_t *)ggsw_in,
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
default:
break;
}
}
void cuda_cmux_tree_64(void *v_stream, void *glwe_out, void *ggsw_in,
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
void *lut_vector, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t r,
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory) {
assert(("Error (GPU Cmux tree): base log should be <= 16", base_log <= 16));
@@ -77,57 +77,56 @@ void cuda_cmux_tree_64(void *v_stream, void *glwe_out, void *ggsw_in,
switch (polynomial_size) {
case 512:
host_cmux_tree<uint64_t, int64_t, Degree<512>>(
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 1024:
host_cmux_tree<uint64_t, int64_t, Degree<1024>>(
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 2048:
host_cmux_tree<uint64_t, int64_t, Degree<2048>>(
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 4096:
host_cmux_tree<uint64_t, int64_t, Degree<4096>>(
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
case 8192:
host_cmux_tree<uint64_t, int64_t, Degree<8192>>(
v_stream, (uint64_t *)glwe_out, (uint64_t *)ggsw_in,
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
l_gadget, r, max_shared_memory);
level_count, r, max_shared_memory);
break;
default:
break;
}
}
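A hedged call sketch for the renamed cmux tree entry point (editorial addition; the prototype comes from the code above, and every value is hypothetical, including r, whose meaning is not restated here):

#include <cstdint>

extern "C" void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out,
                                  void *ggsw_in, void *lut_vector,
                                  uint32_t glwe_dimension,
                                  uint32_t polynomial_size, uint32_t base_log,
                                  uint32_t level_count, uint32_t r,
                                  uint32_t max_shared_memory);

// Sketch: run the cmux tree on hypothetical device buffers.
void run_cmux_tree(void *stream, void *d_glwe_array_out, void *d_ggsw_in,
                   void *d_lut_vector) {
  cuda_cmux_tree_64(stream, d_glwe_array_out, d_ggsw_in, d_lut_vector,
                    /*glwe_dimension=*/1, /*polynomial_size=*/1024,
                    /*base_log=*/8, /*level_count=*/4, /*r=*/5,
                    /*max_shared_memory=*/49152);
}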
void cuda_extract_bits_32(void *v_stream, void *list_lwe_out, void *lwe_in,
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk,
void *fourier_bsk, uint32_t number_of_bits,
uint32_t delta_log, uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
uint32_t number_of_samples) {
void cuda_extract_bits_32(
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
uint32_t number_of_samples) {
assert(("Error (GPU extract bits): base log should be <= 16",
base_log_bsk <= 16));
assert(("Error (GPU extract bits): glwe_dimension should be equal to 1",
glwe_dimension == 1));
assert(("Error (GPU extract bits): lwe_dimension_before should be one of "
assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
"512, 1024, 2048",
lwe_dimension_before == 512 || lwe_dimension_before == 1024 ||
lwe_dimension_before == 2048));
lwe_dimension_in == 512 || lwe_dimension_in == 1024 ||
lwe_dimension_in == 2048));
// The number of samples should be lower than the number of streaming
// multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
// to an occupancy of 50%). The only supported value for k is 1, so
@@ -137,63 +136,68 @@ void cuda_extract_bits_32(void *v_stream, void *list_lwe_out, void *lwe_in,
assert(("Error (GPU extract bits): the number of input LWEs must be lower or "
"equal to the "
"number of streaming multiprocessors on the device divided by 8 * "
"l_gadget_bsk",
number_of_samples <= number_of_sm / 4. / 2. / l_gadget_bsk));
"level_count_bsk",
number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
switch (lwe_dimension_before) {
switch (lwe_dimension_in) {
case 512:
host_extract_bits<uint32_t, Degree<512>>(
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
l_gadget_ksk, number_of_samples);
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
(uint32_t *)lwe_array_in_buffer,
(uint32_t *)lwe_array_in_shifted_buffer,
(uint32_t *)lwe_array_out_ks_buffer,
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
(uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
number_of_samples);
break;
case 1024:
host_extract_bits<uint32_t, Degree<1024>>(
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
l_gadget_ksk, number_of_samples);
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
(uint32_t *)lwe_array_in_buffer,
(uint32_t *)lwe_array_in_shifted_buffer,
(uint32_t *)lwe_array_out_ks_buffer,
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
(uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
number_of_samples);
break;
case 2048:
host_extract_bits<uint32_t, Degree<2048>>(
v_stream, (uint32_t *)list_lwe_out, (uint32_t *)lwe_in,
(uint32_t *)lwe_in_buffer, (uint32_t *)lwe_in_shifted_buffer,
(uint32_t *)lwe_out_ks_buffer, (uint32_t *)lwe_out_pbs_buffer,
(uint32_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint32_t *)ksk,
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
l_gadget_ksk, number_of_samples);
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
(uint32_t *)lwe_array_in_buffer,
(uint32_t *)lwe_array_in_shifted_buffer,
(uint32_t *)lwe_array_out_ks_buffer,
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
(uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
number_of_samples);
break;
default:
break;
}
}
void cuda_extract_bits_64(void *v_stream, void *list_lwe_out, void *lwe_in,
void *lwe_in_buffer, void *lwe_in_shifted_buffer,
void *lwe_out_ks_buffer, void *lwe_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk,
void *fourier_bsk, uint32_t number_of_bits,
uint32_t delta_log, uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after, uint32_t glwe_dimension,
uint32_t base_log_bsk, uint32_t l_gadget_bsk,
uint32_t base_log_ksk, uint32_t l_gadget_ksk,
uint32_t number_of_samples) {
void cuda_extract_bits_64(
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
uint32_t number_of_samples) {
assert(("Error (GPU extract bits): base log should be <= 16",
base_log_bsk <= 16));
assert(("Error (GPU extract bits): glwe_dimension should be equal to 1",
glwe_dimension == 1));
assert(("Error (GPU extract bits): lwe_dimension_before should be one of "
assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
"512, 1024, 2048",
lwe_dimension_before == 512 || lwe_dimension_before == 1024 ||
lwe_dimension_before == 2048));
lwe_dimension_in == 512 || lwe_dimension_in == 1024 ||
lwe_dimension_in == 2048));
// The number of samples should be lower than the number of streaming
// multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
// to an occupancy of 50%). The only supported value for k is 1, so
@@ -203,39 +207,45 @@ void cuda_extract_bits_64(void *v_stream, void *list_lwe_out, void *lwe_in,
assert(("Error (GPU extract bits): the number of input LWEs must be lower or "
"equal to the "
"number of streaming multiprocessors on the device divided by 8 * "
"l_gadget_bsk",
number_of_samples <= number_of_sm / 4. / 2. / l_gadget_bsk));
"level_count_bsk",
number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
switch (lwe_dimension_before) {
switch (lwe_dimension_in) {
case 512:
host_extract_bits<uint64_t, Degree<512>>(
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
l_gadget_ksk, number_of_samples);
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
(uint64_t *)lwe_array_in_buffer,
(uint64_t *)lwe_array_in_shifted_buffer,
(uint64_t *)lwe_array_out_ks_buffer,
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
number_of_samples);
break;
case 1024:
host_extract_bits<uint64_t, Degree<1024>>(
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
l_gadget_ksk, number_of_samples);
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
(uint64_t *)lwe_array_in_buffer,
(uint64_t *)lwe_array_in_shifted_buffer,
(uint64_t *)lwe_array_out_ks_buffer,
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
number_of_samples);
break;
case 2048:
host_extract_bits<uint64_t, Degree<2048>>(
v_stream, (uint64_t *)list_lwe_out, (uint64_t *)lwe_in,
(uint64_t *)lwe_in_buffer, (uint64_t *)lwe_in_shifted_buffer,
(uint64_t *)lwe_out_ks_buffer, (uint64_t *)lwe_out_pbs_buffer,
(uint64_t *)lut_pbs, (uint32_t *)lut_vector_indexes, (uint64_t *)ksk,
(double2 *)fourier_bsk, number_of_bits, delta_log, lwe_dimension_before,
lwe_dimension_after, base_log_bsk, l_gadget_bsk, base_log_ksk,
l_gadget_ksk, number_of_samples);
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
(uint64_t *)lwe_array_in_buffer,
(uint64_t *)lwe_array_in_shifted_buffer,
(uint64_t *)lwe_array_out_ks_buffer,
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
(uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
number_of_samples);
break;
default:
break;
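Note: the streaming-multiprocessor bound enforced by the assert above can be queried at runtime. A minimal host-side sketch using the standard CUDA runtime API; the helper name and the integer rounding are illustrative, not part of this commit:

#include <cuda_runtime.h>

// Upper bound on num_samples for the extract-bits entry points, mirroring
// the assert above: num_samples <= number_of_sm / 4 / 2 / level_count_bsk
// (the 4 * 2 factor encodes the 50% occupancy and k = 1).
static int max_samples_for_extract_bits(int gpu_index, int level_count_bsk) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, gpu_index);
  return prop.multiProcessorCount / (4 * 2 * level_count_bsk);
}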


@@ -68,12 +68,12 @@ template <class params> __device__ void ifft_inplace(double2 *data) {
* Receives an array of GLWE ciphertexts and two indexes to ciphertexts in this
* array, and an array of GGSW ciphertexts with an index to one ciphertext in it.
* Computes a CMUX with these operands and writes the output to a particular
* index of glwe_out.
* index of glwe_array_out.
*
* This function needs polynomial_size threads per block.
*
* - glwe_out: An array where the result should be written to.
* - glwe_in: An array where the GLWE inputs are stored.
* - glwe_array_out: An array where the result should be written to.
* - glwe_array_in: An array where the GLWE inputs are stored.
* - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
* - selected_memory: An array to be used for the accumulators. Can be in the
* shared memory or global memory.
@@ -84,15 +84,15 @@ template <class params> __device__ void ifft_inplace(double2 *data) {
* - glwe_dim: This is k.
* - polynomial_size: size of the polynomials. This is N.
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
* - level_count: number of decomposition levels in the gadget matrix (~4)
* - ggsw_idx: The index of the GGSW we will use.
*/
template <typename Torus, typename STorus, class params>
__device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
char *selected_memory, uint32_t output_idx,
uint32_t input_idx1, uint32_t input_idx2,
uint32_t glwe_dim, uint32_t polynomial_size,
uint32_t base_log, uint32_t l_gadget, uint32_t ggsw_idx) {
__device__ void
cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
char *selected_memory, uint32_t output_idx, uint32_t input_idx1,
uint32_t input_idx2, uint32_t glwe_dim, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {
// Define glwe_sub
Torus *glwe_sub_mask = (Torus *)selected_memory;
@@ -109,18 +109,18 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
double2 *glwe_fft =
(double2 *)body_res_fft + (ptrdiff_t)(polynomial_size / 2);
GadgetMatrix<Torus, params> gadget(base_log, l_gadget);
GadgetMatrix<Torus, params> gadget(base_log, level_count);
/////////////////////////////////////
// glwe2-glwe1
// Copy m0 to shared memory to preserve data
auto m0_mask = &glwe_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
auto m0_mask = &glwe_array_in[input_idx1 * (glwe_dim + 1) * polynomial_size];
auto m0_body = m0_mask + polynomial_size;
// Just gets the pointer for m1 on global memory
auto m1_mask = &glwe_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
auto m1_mask = &glwe_array_in[input_idx2 * (glwe_dim + 1) * polynomial_size];
auto m1_body = m1_mask + polynomial_size;
// Mask
@@ -145,13 +145,11 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
// Subtract the two GLWE operands, decompose the resulting
// polynomial coefficients and multiply each decomposed level
// with the corresponding part of the LUT
for (int decomp_level = 0; decomp_level < l_gadget; decomp_level++) {
for (int level = 0; level < level_count; level++) {
// Decomposition
gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask,
decomp_level);
gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body,
decomp_level);
gadget.decompose_one_level(glwe_mask_decomposed, glwe_sub_mask, level);
gadget.decompose_one_level(glwe_body_decomposed, glwe_sub_body, level);
// First, perform the polynomial multiplication for the mask
synchronize_threads_in_block();
@@ -159,12 +157,10 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
// External product and accumulate
// Get the piece necessary for the multiplication
auto mask_fourier =
get_ith_mask_kth_block(ggsw_in, ggsw_idx, 0, decomp_level,
polynomial_size, glwe_dim, l_gadget);
auto body_fourier =
get_ith_body_kth_block(ggsw_in, ggsw_idx, 0, decomp_level,
polynomial_size, glwe_dim, l_gadget);
auto mask_fourier = get_ith_mask_kth_block(
ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
auto body_fourier = get_ith_body_kth_block(
ggsw_in, ggsw_idx, 0, level, polynomial_size, glwe_dim, level_count);
synchronize_threads_in_block();
@@ -182,10 +178,10 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
// External product and accumulate
// Get the piece necessary for the multiplication
mask_fourier = get_ith_mask_kth_block(ggsw_in, ggsw_idx, 1, decomp_level,
polynomial_size, glwe_dim, l_gadget);
body_fourier = get_ith_body_kth_block(ggsw_in, ggsw_idx, 1, decomp_level,
polynomial_size, glwe_dim, l_gadget);
mask_fourier = get_ith_mask_kth_block(
ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
body_fourier = get_ith_body_kth_block(
ggsw_in, ggsw_idx, 1, level, polynomial_size, glwe_dim, level_count);
synchronize_threads_in_block();
@@ -202,7 +198,8 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
synchronize_threads_in_block();
// Write the output
Torus *mb_mask = &glwe_out[output_idx * (glwe_dim + 1) * polynomial_size];
Torus *mb_mask =
&glwe_array_out[output_idx * (glwe_dim + 1) * polynomial_size];
Torus *mb_body = mb_mask + polynomial_size;
int tid = threadIdx.x;
@@ -221,8 +218,8 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
* ciphertext. The GLWE ciphertexts are picked two-by-two in sequence. Each
* thread block computes a single CMUX.
*
* - glwe_out: An array where the result should be written to.
* - glwe_in: An array where the GLWE inputs are stored.
* - glwe_array_out: An array where the result should be written to.
* - glwe_array_in: An array where the GLWE inputs are stored.
* - ggsw_in: An array where the GGSW input is stored. In the fourier domain.
* - device_mem: A pointer to global memory, used in case the shared memory is
* not big enough to store the accumulators.
@@ -231,15 +228,15 @@ __device__ void cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
* - glwe_dim: This is k.
* - polynomial_size: size of the polynomials. This is N.
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
* - level_count: number of decomposition levels in the gadget matrix (~4)
* - ggsw_idx: The index of the GGSW we will use.
*/
template <typename Torus, typename STorus, class params, sharedMemDegree SMD>
__global__ void
device_batch_cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
char *device_mem, size_t device_memory_size_per_block,
uint32_t glwe_dim, uint32_t polynomial_size,
uint32_t base_log, uint32_t l_gadget, uint32_t ggsw_idx) {
uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) {
int cmux_idx = blockIdx.x;
int output_idx = cmux_idx;
@@ -255,9 +252,10 @@ device_batch_cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
else
selected_memory = &device_mem[blockIdx.x * device_memory_size_per_block];
cmux<Torus, STorus, params>(glwe_out, glwe_in, ggsw_in, selected_memory,
output_idx, input_idx1, input_idx2, glwe_dim,
polynomial_size, base_log, l_gadget, ggsw_idx);
cmux<Torus, STorus, params>(glwe_array_out, glwe_array_in, ggsw_in,
selected_memory, output_idx, input_idx1,
input_idx2, glwe_dim, polynomial_size, base_log,
level_count, ggsw_idx);
}
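For intuition: on cleartexts, the CMUX computed by the cmux/device_batch_cmux routines above collapses to out = in1 + b * (in2 - in1), with b the (GGSW-encrypted) selector bit; homomorphically, the subtraction is gadget-decomposed and multiplied by the GGSW in the Fourier domain. A cleartext sketch, plain C++ with no encryption and an illustrative function name:

#include <cstdint>
#include <vector>

// Cleartext analogue of CMUX over polynomials of size N, coefficient-wise
// mod 2^64: b == 0 selects in1, b == 1 selects in2.
std::vector<uint64_t> cmux_cleartext(const std::vector<uint64_t> &in1,
                                     const std::vector<uint64_t> &in2,
                                     uint64_t b) {
  std::vector<uint64_t> out(in1.size());
  for (size_t i = 0; i < in1.size(); i++)
    out[i] = in1[i] + b * (in2[i] - in1[i]); // wraps like torus arithmetic
  return out;
}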
/*
* This function executes the CMUX tree used by the hybrid packing of the WoPBS.
@@ -265,20 +263,21 @@ device_batch_cmux(Torus *glwe_out, Torus *glwe_in, double2 *ggsw_in,
* Uses shared memory for intermediate results
*
* - v_stream: The CUDA stream that should be used.
* - glwe_out: A device array for the output GLWE ciphertext.
* - glwe_array_out: A device array for the output GLWE ciphertext.
* - ggsw_in: A device array for the GGSW ciphertexts used in each layer.
* - lut_vector: A device array for the GLWE ciphertexts used in the first
* layer.
* - polynomial_size: size of the polynomials. This is N.
* - base_log: log base used for the gadget matrix - B = 2^base_log (~8)
* - l_gadget: number of decomposition levels in the gadget matrix (~4)
* - level_count: number of decomposition levels in the gadget matrix (~4)
* - r: Number of layers in the tree.
*/
template <typename Torus, typename STorus, class params>
void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
Torus *lut_vector, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t l_gadget, uint32_t r, uint32_t max_shared_memory) {
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory) {
auto stream = static_cast<cudaStream_t *>(v_stream);
int num_lut = (1 << r);
@@ -299,7 +298,7 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
//////////////////////
double2 *d_ggsw_fft_in;
int ggsw_size = r * polynomial_size * (glwe_dimension + 1) *
(glwe_dimension + 1) * l_gadget;
(glwe_dimension + 1) * level_count;
#if (CUDART_VERSION < 11020)
checkCudaErrors(
@@ -311,7 +310,7 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
r, glwe_dimension,
polynomial_size, l_gadget);
polynomial_size, level_count);
//////////////////////
@@ -368,7 +367,7 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
<<<grid, thds, memory_needed_per_block, *stream>>>(
output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
glwe_dimension, // k
polynomial_size, base_log, l_gadget,
polynomial_size, base_log, level_count,
layer_idx // r
);
else
@@ -376,17 +375,18 @@ void host_cmux_tree(void *v_stream, Torus *glwe_out, Torus *ggsw_in,
<<<grid, thds, memory_needed_per_block, *stream>>>(
output, input, d_ggsw_fft_in, d_mem, memory_needed_per_block,
glwe_dimension, // k
polynomial_size, base_log, l_gadget,
polynomial_size, base_log, level_count,
layer_idx // r
);
}
checkCudaErrors(cudaMemcpyAsync(
glwe_out, output, (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
cudaMemcpyDeviceToDevice, *stream));
checkCudaErrors(
cudaMemcpyAsync(glwe_array_out, output,
(glwe_dimension + 1) * polynomial_size * sizeof(Torus),
cudaMemcpyDeviceToDevice, *stream));
// We only need synchronization to ensure that data is in glwe_out before
// returning. Memory release can be added to the stream and processed
// We only need synchronization to ensure that data is in glwe_array_out
// before returning. Memory release can be added to the stream and processed
// later.
checkCudaErrors(cudaStreamSynchronize(*stream));
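The ggsw_size formula above counts r ciphertexts of (k + 1) * level_count rows, each row holding (k + 1) polynomials of N coefficients; in the Fourier domain each polynomial is compressed to N / 2 double2 values. A quick size check with illustrative parameters (k = 1, N = 1024, level_count = 4, r = 10; these are not defaults taken from the library):

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t r = 10, polynomial_size = 1024, glwe_dimension = 1, level_count = 4;
  // Same formula as in host_cmux_tree (count of torus coefficients):
  uint64_t ggsw_size = r * polynomial_size * (glwe_dimension + 1) *
                       (glwe_dimension + 1) * level_count;
  // Compressed Fourier representation: N / 2 complex values of 16 bytes each.
  uint64_t fourier_bytes = ggsw_size / 2 * 16;
  printf("%llu coefficients -> %llu bytes on device\n",
         (unsigned long long)ggsw_size, (unsigned long long)fourier_bytes);
  return 0;
}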
@@ -492,36 +492,38 @@ __global__ void fill_lut_body_for_current_bit(Torus *lut, Torus value) {
// instead of alpha= 1ll << (ciphertext_n_bits - delta_log - bit_idx - 1)
template <typename Torus, class params>
__global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
Torus *pbs_lwe_out, Torus add_value,
Torus *pbs_lwe_array_out, Torus add_value,
Torus mul_value) {
size_t tid = threadIdx.x;
size_t blockId = blockIdx.x;
auto cur_shifted_lwe = &shifted_lwe[blockId * (params::degree + 1)];
auto cur_state_lwe = &state_lwe[blockId * (params::degree + 1)];
auto cur_pbs_lwe_out = &pbs_lwe_out[blockId * (params::degree + 1)];
auto cur_pbs_lwe_array_out =
&pbs_lwe_array_out[blockId * (params::degree + 1)];
#pragma unroll
for (int i = 0; i < params::opt; i++) {
cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_out[tid];
cur_shifted_lwe[tid] = cur_state_lwe[tid] -= cur_pbs_lwe_array_out[tid];
cur_shifted_lwe[tid] *= mul_value;
tid += params::degree / params::opt;
}
if (threadIdx.x == params::degree / params::opt - 1) {
cur_shifted_lwe[params::degree] = cur_state_lwe[params::degree] -=
(cur_pbs_lwe_out[params::degree] + add_value);
(cur_pbs_lwe_array_out[params::degree] + add_value);
cur_shifted_lwe[params::degree] *= mul_value;
}
}
template <typename Torus, class params>
__host__ void host_extract_bits(
void *v_stream, Torus *list_lwe_out, Torus *lwe_in, Torus *lwe_in_buffer,
Torus *lwe_in_shifted_buffer, Torus *lwe_out_ks_buffer,
Torus *lwe_out_pbs_buffer, Torus *lut_pbs, uint32_t *lut_vector_indexes,
Torus *ksk, double2 *fourier_bsk, uint32_t number_of_bits,
uint32_t delta_log, uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after, uint32_t base_log_bsk, uint32_t l_gadget_bsk,
uint32_t base_log_ksk, uint32_t l_gadget_ksk, uint32_t number_of_samples) {
void *v_stream, Torus *list_lwe_array_out, Torus *lwe_array_in,
Torus *lwe_array_in_buffer, Torus *lwe_array_in_shifted_buffer,
Torus *lwe_array_out_ks_buffer, Torus *lwe_array_out_pbs_buffer,
Torus *lut_pbs, uint32_t *lut_vector_indexes, Torus *ksk,
double2 *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_ksk,
uint32_t level_count_ksk, uint32_t number_of_samples) {
auto stream = static_cast<cudaStream_t *>(v_stream);
uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
@@ -530,38 +532,38 @@ __host__ void host_extract_bits(
int threads = params::degree / params::opt;
copy_and_shift_lwe<Torus, params><<<blocks, threads, 0, *stream>>>(
lwe_in_buffer, lwe_in_shifted_buffer, lwe_in,
lwe_array_in_buffer, lwe_array_in_shifted_buffer, lwe_array_in,
1ll << (ciphertext_n_bits - delta_log - 1));
for (int bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
cuda_keyswitch_lwe_ciphertext_vector(
v_stream, lwe_out_ks_buffer, lwe_in_shifted_buffer, ksk,
lwe_dimension_before, lwe_dimension_after, base_log_ksk, l_gadget_ksk,
1);
v_stream, lwe_array_out_ks_buffer, lwe_array_in_shifted_buffer, ksk,
lwe_dimension_in, lwe_dimension_out, base_log_ksk, level_count_ksk, 1);
copy_small_lwe<<<1, 256, 0, *stream>>>(
list_lwe_out, lwe_out_ks_buffer, lwe_dimension_after + 1,
list_lwe_array_out, lwe_array_out_ks_buffer, lwe_dimension_out + 1,
number_of_bits, number_of_bits - bit_idx - 1);
if (bit_idx == number_of_bits - 1) {
break;
}
add_to_body<Torus><<<1, 1, 0, *stream>>>(
lwe_out_ks_buffer, lwe_dimension_after, 1ll << (ciphertext_n_bits - 2));
add_to_body<Torus><<<1, 1, 0, *stream>>>(lwe_array_out_ks_buffer,
lwe_dimension_out,
1ll << (ciphertext_n_bits - 2));
fill_lut_body_for_current_bit<Torus, params>
<<<blocks, threads, 0, *stream>>>(
lut_pbs, 0ll - 1ll << (delta_log - 1 + bit_idx));
host_bootstrap_low_latency<Torus, params>(
v_stream, lwe_out_pbs_buffer, lut_pbs, lut_vector_indexes,
lwe_out_ks_buffer, fourier_bsk, lwe_dimension_after,
lwe_dimension_before, base_log_bsk, l_gadget_bsk, number_of_samples, 1);
v_stream, lwe_array_out_pbs_buffer, lut_pbs, lut_vector_indexes,
lwe_array_out_ks_buffer, fourier_bsk, lwe_dimension_out,
lwe_dimension_in, base_log_bsk, level_count_bsk, number_of_samples, 1);
add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
lwe_in_shifted_buffer, lwe_in_buffer, lwe_out_pbs_buffer,
1ll << (delta_log - 1 + bit_idx),
lwe_array_in_shifted_buffer, lwe_array_in_buffer,
lwe_array_out_pbs_buffer, 1ll << (delta_log - 1 + bit_idx),
1ll << (ciphertext_n_bits - delta_log - bit_idx - 2));
}
}
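The loop in host_extract_bits peels the message off one bit at a time, least significant bit first: the current bit is brought to the most significant position by the shift, recovered through keyswitch + PBS, then subtracted from the running state (add_sub_and_mul_lwe) before the next iteration. On cleartexts this collapses to the following model (values are illustrative, not library defaults):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t number_of_bits = 4, delta_log = 60;
  uint64_t state = 0b1011ull << delta_log; // encoded 4-bit message
  for (uint32_t bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
    // What the keyswitch + PBS recover homomorphically:
    uint64_t bit = (state >> (delta_log + bit_idx)) & 1;
    // Subtract the recovered bit so the next PBS sees a clean value:
    state -= bit << (delta_log + bit_idx);
    printf("bit %u = %llu\n", bit_idx, (unsigned long long)bit);
  }
  return 0; // prints 1, 1, 0, 1 (LSB first)
}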


@@ -9,16 +9,17 @@
__device__ inline int get_start_ith_ggsw(int i, uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget) {
uint32_t level_count) {
return i * polynomial_size / 2 * (glwe_dimension + 1) * (glwe_dimension + 1) *
l_gadget;
level_count;
}
template <typename T>
__device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
int glwe_dimension, uint32_t l_gadget) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension, l_gadget) +
int glwe_dimension, uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1)];
@@ -27,8 +28,9 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
template <typename T>
__device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
uint32_t polynomial_size,
int glwe_dimension, uint32_t l_gadget) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension, l_gadget) +
int glwe_dimension, uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1) +
@@ -69,14 +71,14 @@ void cuda_initialize_twiddles(uint32_t polynomial_size, uint32_t gpu_index) {
template <typename T, typename ST>
void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
uint32_t gpu_index, uint32_t input_lwe_dim,
uint32_t glwe_dim, uint32_t l_gadget,
uint32_t glwe_dim, uint32_t level_count,
uint32_t polynomial_size) {
cudaSetDevice(gpu_index);
int shared_memory_size = sizeof(double) * polynomial_size;
int total_polynomials =
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * l_gadget;
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
// Here the buffer size is the size of double2 times the number of polynomials
// times the polynomial size over 2 because the polynomials are compressed
@@ -142,21 +144,21 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
void cuda_convert_lwe_bootstrap_key_32(void *dest, void *src, void *v_stream,
uint32_t gpu_index,
uint32_t input_lwe_dim,
uint32_t glwe_dim, uint32_t l_gadget,
uint32_t glwe_dim, uint32_t level_count,
uint32_t polynomial_size) {
cuda_convert_lwe_bootstrap_key<uint32_t, int32_t>(
(double2 *)dest, (int32_t *)src, v_stream, gpu_index, input_lwe_dim,
glwe_dim, l_gadget, polynomial_size);
glwe_dim, level_count, polynomial_size);
}
void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream,
uint32_t gpu_index,
uint32_t input_lwe_dim,
uint32_t glwe_dim, uint32_t l_gadget,
uint32_t glwe_dim, uint32_t level_count,
uint32_t polynomial_size) {
cuda_convert_lwe_bootstrap_key<uint64_t, int64_t>(
(double2 *)dest, (int64_t *)src, v_stream, gpu_index, input_lwe_dim,
glwe_dim, l_gadget, polynomial_size);
glwe_dim, level_count, polynomial_size);
}
// We need these lines so the compiler knows how to specialize these functions
@@ -164,31 +166,31 @@ template __device__ uint64_t *get_ith_mask_kth_block(uint64_t *ptr, int i,
int k, int level,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget);
uint32_t level_count);
template __device__ uint32_t *get_ith_mask_kth_block(uint32_t *ptr, int i,
int k, int level,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget);
uint32_t level_count);
template __device__ double2 *get_ith_mask_kth_block(double2 *ptr, int i, int k,
int level,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget);
uint32_t level_count);
template __device__ uint64_t *get_ith_body_kth_block(uint64_t *ptr, int i,
int k, int level,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget);
uint32_t level_count);
template __device__ uint32_t *get_ith_body_kth_block(uint32_t *ptr, int i,
int k, int level,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget);
uint32_t level_count);
template __device__ double2 *get_ith_body_kth_block(double2 *ptr, int i, int k,
int level,
uint32_t polynomial_size,
int glwe_dimension,
uint32_t l_gadget);
uint32_t level_count);
#endif // CNCRT_BSK_H
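To make the pointer arithmetic of get_start_ith_ggsw / get_ith_mask_kth_block concrete, here is a host-side transcription of the same index math (offsets are counted in double2 elements, hence the polynomial_size / 2 factor; the parameter values in main are illustrative):

#include <cstdint>
#include <cstdio>

static uint64_t mask_offset(int i, int k, int level, uint32_t polynomial_size,
                            int glwe_dimension, uint32_t level_count) {
  uint64_t half_n = polynomial_size / 2;
  // Start of the i-th GGSW, then the level-th row, then the k-th block:
  return (uint64_t)i * half_n * (glwe_dimension + 1) * (glwe_dimension + 1) *
             level_count +
         (uint64_t)level * half_n * (glwe_dimension + 1) * (glwe_dimension + 1) +
         (uint64_t)k * half_n * (glwe_dimension + 1);
}

int main() {
  // Second GGSW (i = 1), mask block k = 0, level 2, N = 1024, k = 1, l = 4:
  printf("offset = %llu\n",
         (unsigned long long)mask_offset(1, 0, 2, 1024, 1, 4));
  return 0; // prints offset = 12288
}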


@@ -7,20 +7,20 @@
#pragma once
template <typename T, class params> class GadgetMatrix {
private:
uint32_t l_gadget;
uint32_t level_count;
uint32_t base_log;
uint32_t mask;
uint32_t halfbg;
T offset;
public:
__device__ GadgetMatrix(uint32_t base_log, uint32_t l_gadget)
: base_log(base_log), l_gadget(l_gadget) {
__device__ GadgetMatrix(uint32_t base_log, uint32_t level_count)
: base_log(base_log), level_count(level_count) {
uint32_t bg = 1 << base_log;
this->halfbg = bg / 2;
this->mask = bg - 1;
T temp = 0;
for (int i = 0; i < this->l_gadget; i++) {
for (int i = 0; i < this->level_count; i++) {
temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
}
this->offset = temp * this->halfbg;
@@ -62,20 +62,20 @@ public:
template <typename T> class GadgetMatrixSingle {
private:
uint32_t l_gadget;
uint32_t level_count;
uint32_t base_log;
uint32_t mask;
uint32_t halfbg;
T offset;
public:
__device__ GadgetMatrixSingle(uint32_t base_log, uint32_t l_gadget)
: base_log(base_log), l_gadget(l_gadget) {
__device__ GadgetMatrixSingle(uint32_t base_log, uint32_t level_count)
: base_log(base_log), level_count(level_count) {
uint32_t bg = 1 << base_log;
this->halfbg = bg / 2;
this->mask = bg - 1;
T temp = 0;
for (int i = 0; i < this->l_gadget; i++) {
for (int i = 0; i < this->level_count; i++) {
temp += 1ULL << (sizeof(T) * 8 - (i + 1) * this->base_log);
}
this->offset = temp * this->halfbg;
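The offset computed in both constructors recenters each decomposed digit into [-B/2, B/2). A worked example for T = uint64_t, base_log = 8, level_count = 4 (illustrative parameters, same arithmetic as the constructor loop above):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t base_log = 8, level_count = 4;
  const uint64_t halfbg = (1ull << base_log) / 2; // 128
  uint64_t temp = 0;
  for (uint32_t i = 0; i < level_count; i++)
    temp += 1ULL << (64 - (i + 1) * base_log); // 2^56 + 2^48 + 2^40 + 2^32
  printf("offset = 0x%016llx\n",
         (unsigned long long)(temp * halfbg)); // 0x8080808000000000
  return 0;
}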


@@ -44,13 +44,13 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
template <typename T, typename ST, class params>
void batch_fft_ggsw_vector(void *v_stream, double2 *dest, T *src, uint32_t r,
uint32_t glwe_dim, uint32_t polynomial_size,
uint32_t l_gadget) {
uint32_t level_count) {
auto stream = static_cast<cudaStream_t *>(v_stream);
int shared_memory_size = sizeof(double) * polynomial_size;
int total_polynomials = r * (glwe_dim + 1) * (glwe_dim + 1) * l_gadget;
int total_polynomials = r * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
int gridSize = total_polynomials;
int blockSize = polynomial_size / params::opt;


@@ -23,8 +23,8 @@ __device__ inline Torus typecast_double_to_torus(double x) {
template <typename T>
__device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
uint32_t l_gadget) {
T shift = sizeof(T) * 8 - l_gadget * base_log;
uint32_t level_count) {
T shift = sizeof(T) * 8 - level_count * base_log;
T mask = 1ll << (shift - 1);
T b = (x & mask) >> (shift - 1);
T res = x >> shift;
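The hunk ends before the return; the usual completion (assumed here) adds the rounding bit back and shifts left, so x is rounded to the nearest multiple of 2^shift with shift = W - level_count * base_log. A cleartext check under that assumption:

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t base_log = 8, level_count = 4;
  const uint64_t shift = 64 - level_count * base_log; // 32
  uint64_t x = 0x123456789abcdef0ull;
  uint64_t b = (x >> (shift - 1)) & 1;            // rounding bit
  uint64_t rounded = ((x >> shift) + b) << shift; // assumed tail of the function
  printf("0x%016llx -> 0x%016llx\n", (unsigned long long)x,
         (unsigned long long)rounded); // -> 0x1234567900000000
  return 0;
}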


@@ -6,42 +6,40 @@
/* Perform keyswitch on a batch of input LWE ciphertexts for 32 bits
*
* - lwe_out: output batch of num_samples keyswitched ciphertexts c =
* - lwe_array_out: output batch of num_samples keyswitched ciphertexts c =
* (a0,..an-1,b) where n is the LWE dimension
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
* mask values + 1 body value
*
* This function calls a wrapper around a device kernel that performs the keyswitch
* - num_samples blocks of threads are launched
*/
void cuda_keyswitch_lwe_ciphertext_vector_32(
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples) {
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples) {
cuda_keyswitch_lwe_ciphertext_vector(
v_stream, static_cast<uint32_t *>(lwe_out),
static_cast<uint32_t *>(lwe_in), static_cast<uint32_t *>(ksk),
lwe_dimension_before, lwe_dimension_after, base_log, l_gadget,
num_samples);
v_stream, static_cast<uint32_t *>(lwe_array_out),
static_cast<uint32_t *>(lwe_array_in), static_cast<uint32_t *>(ksk),
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
}
/* Perform keyswitch on a batch of input LWE ciphertexts for 64 bits
*
* - lwe_out: output batch of num_samples keyswitched ciphertexts c =
* - lwe_array_out: output batch of num_samples keyswitched ciphertexts c =
* (a0,..an-1,b) where n is the LWE dimension
* - lwe_in: input batch of num_samples LWE ciphertexts, containing n
* - lwe_array_in: input batch of num_samples LWE ciphertexts, containing n
* mask values + 1 body value
*
* This function calls a wrapper around a device kernel that performs the keyswitch
* - num_samples blocks of threads are launched
*/
void cuda_keyswitch_lwe_ciphertext_vector_64(
void *v_stream, void *lwe_out, void *lwe_in, void *ksk,
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples) {
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples) {
cuda_keyswitch_lwe_ciphertext_vector(
v_stream, static_cast<uint64_t *>(lwe_out),
static_cast<uint64_t *>(lwe_in), static_cast<uint64_t *>(ksk),
lwe_dimension_before, lwe_dimension_after, base_log, l_gadget,
num_samples);
v_stream, static_cast<uint64_t *>(lwe_array_out),
static_cast<uint64_t *>(lwe_array_in), static_cast<uint64_t *>(ksk),
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
}
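A minimal driver for the 64-bit entry point; the device buffers are assumed to be allocated with cudaMalloc and filled by the caller, and the parameter values are illustrative, not recommended ones:

#include <cstdint>
#include <cuda_runtime.h>

// Prototype as declared earlier in this diff.
void cuda_keyswitch_lwe_ciphertext_vector_64(
    void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
    uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
    uint32_t level_count, uint32_t num_samples);

void keyswitch_batch(void *d_lwe_array_out, void *d_lwe_array_in, void *d_ksk,
                     uint32_t num_samples) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // The API takes the stream as an opaque pointer (v_stream); the host
  // wrapper synchronizes the stream internally before returning.
  cuda_keyswitch_lwe_ciphertext_vector_64(
      &stream, d_lwe_array_out, d_lwe_array_in, d_ksk,
      /*lwe_dimension_in=*/1024, /*lwe_dimension_out=*/512,
      /*base_log=*/4, /*level_count=*/8, num_samples);
  cudaStreamDestroy(&stream);
}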


@@ -9,10 +9,10 @@
template <typename Torus>
__device__ Torus *get_ith_block(Torus *ksk, int i, int level,
uint32_t lwe_dimension_after,
uint32_t l_gadget) {
int pos = i * l_gadget * (lwe_dimension_after + 1) +
level * (lwe_dimension_after + 1);
uint32_t lwe_dimension_out,
uint32_t level_count) {
int pos = i * level_count * (lwe_dimension_out + 1) +
level * (lwe_dimension_out + 1);
Torus *ptr = &ksk[pos];
return ptr;
}
@@ -42,21 +42,22 @@ __device__ Torus decompose_one(Torus &state, Torus mod_b_mask, int base_log) {
*
*/
template <typename Torus>
__global__ void keyswitch(Torus *lwe_out, Torus *lwe_in, Torus *ksk,
uint32_t lwe_dimension_before,
uint32_t lwe_dimension_after, uint32_t base_log,
uint32_t l_gadget, int lwe_lower, int lwe_upper,
int cutoff) {
__global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log, uint32_t level_count,
int lwe_lower, int lwe_upper, int cutoff) {
int tid = threadIdx.x;
extern __shared__ char sharedmem[];
Torus *local_lwe_out = (Torus *)sharedmem;
Torus *local_lwe_array_out = (Torus *)sharedmem;
auto block_lwe_in = get_chunk(lwe_in, blockIdx.x, lwe_dimension_before + 1);
auto block_lwe_out = get_chunk(lwe_out, blockIdx.x, lwe_dimension_after + 1);
auto block_lwe_array_in =
get_chunk(lwe_array_in, blockIdx.x, lwe_dimension_in + 1);
auto block_lwe_array_out =
get_chunk(lwe_array_out, blockIdx.x, lwe_dimension_out + 1);
auto gadget = GadgetMatrixSingle<Torus>(base_log, l_gadget);
auto gadget = GadgetMatrixSingle<Torus>(base_log, level_count);
int lwe_part_per_thd;
if (tid < cutoff) {
@@ -68,49 +69,51 @@ __global__ void keyswitch(Torus *lwe_out, Torus *lwe_in, Torus *ksk,
for (int k = 0; k < lwe_part_per_thd; k++) {
int idx = tid + k * blockDim.x;
local_lwe_out[idx] = 0;
local_lwe_array_out[idx] = 0;
}
if (tid == 0) {
local_lwe_out[lwe_dimension_after] = block_lwe_in[lwe_dimension_before];
local_lwe_array_out[lwe_dimension_out] =
block_lwe_array_in[lwe_dimension_in];
}
for (int i = 0; i < lwe_dimension_before; i++) {
for (int i = 0; i < lwe_dimension_in; i++) {
__syncthreads();
Torus a_i = round_to_closest_multiple(block_lwe_in[i], base_log, l_gadget);
Torus a_i =
round_to_closest_multiple(block_lwe_array_in[i], base_log, level_count);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * l_gadget);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;
for (int j = 0; j < l_gadget; j++) {
auto ksk_block = get_ith_block(ksk, i, l_gadget - j - 1,
lwe_dimension_after, l_gadget);
for (int j = 0; j < level_count; j++) {
auto ksk_block = get_ith_block(ksk, i, level_count - j - 1,
lwe_dimension_out, level_count);
Torus decomposed = decompose_one<Torus>(state, mod_b_mask, base_log);
for (int k = 0; k < lwe_part_per_thd; k++) {
int idx = tid + k * blockDim.x;
local_lwe_out[idx] -= (Torus)ksk_block[idx] * decomposed;
local_lwe_array_out[idx] -= (Torus)ksk_block[idx] * decomposed;
}
}
}
for (int k = 0; k < lwe_part_per_thd; k++) {
int idx = tid + k * blockDim.x;
block_lwe_out[idx] = local_lwe_out[idx];
block_lwe_array_out[idx] = local_lwe_array_out[idx];
}
}
/// assume lwe_in is in GPU memory
/// assume lwe_array_in is in GPU memory
template <typename Torus>
__host__ void cuda_keyswitch_lwe_ciphertext_vector(
void *v_stream, Torus *lwe_out, Torus *lwe_in, Torus *ksk,
uint32_t lwe_dimension_before, uint32_t lwe_dimension_after,
uint32_t base_log, uint32_t l_gadget, uint32_t num_samples) {
void *v_stream, Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples) {
constexpr int ideal_threads = 128;
int lwe_dim = lwe_dimension_after + 1;
int lwe_dim = lwe_dimension_out + 1;
int lwe_lower, lwe_upper, cutoff;
if (lwe_dim % ideal_threads == 0) {
lwe_lower = lwe_dim / ideal_threads;
@@ -124,11 +127,11 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
lwe_upper = (int)ceil((double)lwe_dim / (double)ideal_threads);
}
int lwe_size_after = (lwe_dimension_after + 1) * num_samples;
int lwe_size_after = (lwe_dimension_out + 1) * num_samples;
int shared_mem = sizeof(Torus) * (lwe_dimension_after + 1);
int shared_mem = sizeof(Torus) * (lwe_dimension_out + 1);
cudaMemset(lwe_out, 0, sizeof(Torus) * lwe_size_after);
cudaMemset(lwe_array_out, 0, sizeof(Torus) * lwe_size_after);
dim3 grid(num_samples, 1, 1);
dim3 threads(ideal_threads, 1, 1);
@@ -138,8 +141,8 @@ __host__ void cuda_keyswitch_lwe_ciphertext_vector(
auto stream = static_cast<cudaStream_t *>(v_stream);
keyswitch<<<grid, threads, shared_mem, *stream>>>(
lwe_out, lwe_in, ksk, lwe_dimension_before, lwe_dimension_after, base_log,
l_gadget, lwe_lower, lwe_upper, cutoff);
lwe_array_out, lwe_array_in, ksk, lwe_dimension_in, lwe_dimension_out,
base_log, level_count, lwe_lower, lwe_upper, cutoff);
cudaStreamSynchronize(*stream);
}
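The middle of this function is elided by the diff; the partition it sets up spreads lwe_dimension_out + 1 output coefficients over 128 threads, with `cutoff` threads taking one extra coefficient when the division is not exact. The cutoff formula below is an assumption, the rest mirrors the visible code:

#include <cstdio>

int main() {
  const int ideal_threads = 128;
  int lwe_dim = 512 + 1; // lwe_dimension_out + 1, illustrative
  int lwe_lower = lwe_dim / ideal_threads;                       // 4
  int lwe_upper = (lwe_dim + ideal_threads - 1) / ideal_threads; // 5
  int cutoff = lwe_dim % ideal_threads;                          // 1 (assumed)
  printf("lower=%d upper=%d cutoff=%d\n", lwe_lower, lwe_upper, cutoff);
  return 0;
}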


@@ -145,12 +145,12 @@ multiply_by_monomial_negacyclic_and_sub_polynomial(T *acc, T *result_acc,
*/
template <typename T, int elems_per_thread, int block_size>
__device__ void round_to_closest_multiple_inplace(T *rotated_acc, int base_log,
int l_gadget) {
int level_count) {
int tid = threadIdx.x;
for (int i = 0; i < elems_per_thread; i++) {
T x_acc = rotated_acc[tid];
T shift = sizeof(T) * 8 - l_gadget * base_log;
T shift = sizeof(T) * 8 - level_count * base_log;
T mask = 1ll << (shift - 1);
T b_acc = (x_acc & mask) >> (shift - 1);
T res_acc = x_acc >> shift;
@@ -191,13 +191,13 @@ __device__ void add_to_torus(double2 *m_values, Torus *result) {
}
template <typename Torus, class params>
__device__ void sample_extract_body(Torus *lwe_out, Torus *accumulator) {
__device__ void sample_extract_body(Torus *lwe_array_out, Torus *accumulator) {
// Set first coefficient of the accumulator as the body of the LWE sample
lwe_out[params::degree] = accumulator[0];
lwe_array_out[params::degree] = accumulator[0];
}
template <typename Torus, class params>
__device__ void sample_extract_mask(Torus *lwe_out, Torus *accumulator) {
__device__ void sample_extract_mask(Torus *lwe_array_out, Torus *accumulator) {
// Set ACC = -ACC
// accumulator.negate_inplace();
@@ -257,12 +257,12 @@ __device__ void sample_extract_mask(Torus *lwe_out, Torus *accumulator) {
synchronize_threads_in_block();
// Copy to the mask of the LWE sample
// accumulator.copy_into(lwe_out);
// accumulator.copy_into(lwe_array_out);
tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
lwe_out[tid] = accumulator[tid];
lwe_array_out[tid] = accumulator[tid];
tid = tid + params::degree / params::opt;
}
}
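sample_extract_body / sample_extract_mask implement the textbook sample extraction of coefficient 0 for k = 1: from a GLWE ciphertext (A(X), B(X)) over Z[X]/(X^N + 1), the extracted LWE sample is (a_0, ..., a_{N-1}, b) with a_0 = A_0, a_i = -A_{N-i} for i >= 1, and b = B_0; the negations come from the negacyclic wrap-around handled above. A cleartext sketch of that layout, stated under the standard convention rather than read off this diff:

#include <cstdint>
#include <vector>

void sample_extract_cleartext(const std::vector<uint64_t> &A,
                              const std::vector<uint64_t> &B,
                              std::vector<uint64_t> &lwe) {
  size_t N = A.size();
  lwe.resize(N + 1);
  lwe[0] = A[0];
  for (size_t i = 1; i < N; i++)
    lwe[i] = (uint64_t)(-(int64_t)A[N - i]); // negacyclic sign flip
  lwe[N] = B[0];                             // body, as in sample_extract_body
}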


@@ -369,13 +369,13 @@ public:
}
__device__ void round_to_closest_multiple_inplace(uint32_t base_log,
uint32_t l_gadget) {
uint32_t level_count) {
int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
T x = coefficients[tid];
T shift = sizeof(T) * 8 - l_gadget * base_log;
T shift = sizeof(T) * 8 - level_count * base_log;
T mask = 1ll << (shift - 1);
T b = (x & mask) >> (shift - 1);
T res = x >> shift;