mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(cuda): Adds a parameter in the CUDA host functions passing the gpu index that should be used.
This commit is contained in:
@@ -20,69 +20,71 @@ void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream,
|
||||
uint32_t polynomial_size);
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_array_out, void *test_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
|
||||
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
|
||||
|
||||
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
void cuda_cmux_tree_32(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
|
||||
void *ggsw_in, void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
void cuda_cmux_tree_64(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
|
||||
void *ggsw_in, void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
void cuda_blind_rotate_and_sample_extraction_64(
|
||||
void *v_stream, void *lwe_out, void *ggsw_in, void *lut_vector,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
|
||||
uint32_t max_shared_memory);
|
||||
|
||||
void cuda_extract_bits_32(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
void *v_stream, uint32_t gpu_index, void *list_lwe_array_out,
|
||||
void *lwe_array_in, void *lwe_array_in_buffer,
|
||||
void *lwe_array_in_shifted_buffer, void *lwe_array_out_ks_buffer,
|
||||
void *lwe_array_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
|
||||
void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples);
|
||||
|
||||
void cuda_extract_bits_64(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
void *v_stream, uint32_t gpu_index, void *list_lwe_array_out,
|
||||
void *lwe_array_in, void *lwe_array_in_buffer,
|
||||
void *lwe_array_in_shifted_buffer, void *lwe_array_out_ks_buffer,
|
||||
void *lwe_array_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
|
||||
void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,14 +6,14 @@
|
||||
extern "C" {
|
||||
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples);
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
|
||||
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples);
|
||||
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples);
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
|
||||
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples);
|
||||
}
|
||||
|
||||
#endif // CNCRT_KS_H_
|
||||
|
||||
@@ -58,7 +58,7 @@
|
||||
*/
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
@@ -77,35 +77,35 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_amortized<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_amortized<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_amortized<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_bootstrap_amortized<uint32_t, Degree<4096>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_bootstrap_amortized<uint32_t, Degree<8192>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
@@ -116,7 +116,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
|
||||
}
|
||||
|
||||
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
@@ -135,35 +135,35 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_amortized<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_amortized<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_amortized<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_bootstrap_amortized<uint64_t, Degree<4096>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_bootstrap_amortized<uint64_t, Degree<8192>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory);
|
||||
|
||||
@@ -277,15 +277,13 @@ __global__ void device_bootstrap_amortized(
|
||||
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_bootstrap_amortized(
|
||||
void *v_stream, Torus *lwe_array_out, Torus *lut_vector,
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_array_in,
|
||||
double2 *bootstrapping_key, uint32_t input_lwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, uint32_t num_lut_vectors,
|
||||
uint32_t lwe_idx, uint32_t max_shared_memory) {
|
||||
|
||||
uint32_t gpu_index = 0;
|
||||
|
||||
int SM_FULL = sizeof(Torus) * polynomial_size + // accumulator mask
|
||||
sizeof(Torus) * polynomial_size + // accumulator body
|
||||
sizeof(Torus) * polynomial_size + // accumulator mask rotated
|
||||
|
||||
@@ -57,7 +57,7 @@
|
||||
* values for the FFT
|
||||
*/
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
@@ -85,21 +85,21 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_low_latency<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
@@ -110,7 +110,7 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
|
||||
}
|
||||
|
||||
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_array_out, void *lut_vector,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector,
|
||||
void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
|
||||
@@ -138,21 +138,21 @@ void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 1024:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
break;
|
||||
case 2048:
|
||||
host_bootstrap_low_latency<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector,
|
||||
(uint32_t *)lut_vector_indexes, (uint64_t *)lwe_array_in,
|
||||
(double2 *)bootstrapping_key, lwe_dimension, polynomial_size, base_log,
|
||||
level_count, num_samples, num_lut_vectors);
|
||||
|
||||
@@ -246,15 +246,12 @@ __global__ void device_bootstrap_low_latency(
|
||||
* of bootstrapping
|
||||
*/
|
||||
template <typename Torus, class params>
|
||||
__host__ void
|
||||
host_bootstrap_low_latency(void *v_stream, Torus *lwe_array_out,
|
||||
Torus *lut_vector, uint32_t *lut_vector_indexes,
|
||||
Torus *lwe_array_in, double2 *bootstrapping_key,
|
||||
uint32_t lwe_dimension, uint32_t polynomial_size,
|
||||
uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples, uint32_t num_lut_vectors) {
|
||||
|
||||
uint32_t gpu_index = 0;
|
||||
__host__ void host_bootstrap_low_latency(
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lut_vector,
|
||||
uint32_t *lut_vector_indexes, Torus *lwe_array_in,
|
||||
double2 *bootstrapping_key, uint32_t lwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples, uint32_t num_lut_vectors) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "bootstrap_wop.cuh"
|
||||
|
||||
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
void cuda_cmux_tree_32(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
|
||||
void *ggsw_in, void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory) {
|
||||
@@ -22,31 +22,31 @@ void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<4096>>(
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_cmux_tree<uint32_t, int32_t, Degree<8192>>(
|
||||
v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
|
||||
(uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
@@ -55,8 +55,8 @@ void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t glwe_dimension,
|
||||
void cuda_cmux_tree_64(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
|
||||
void *ggsw_in, void *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory) {
|
||||
@@ -77,31 +77,31 @@ void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<4096>>(
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_cmux_tree<uint64_t, int64_t, Degree<8192>>(
|
||||
v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
|
||||
level_count, r, max_shared_memory);
|
||||
break;
|
||||
@@ -111,13 +111,14 @@ void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
|
||||
}
|
||||
|
||||
void cuda_extract_bits_32(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
void *v_stream, uint32_t gpu_index, void *list_lwe_array_out,
|
||||
void *lwe_array_in, void *lwe_array_in_buffer,
|
||||
void *lwe_array_in_shifted_buffer, void *lwe_array_out_ks_buffer,
|
||||
void *lwe_array_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
|
||||
void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
assert(("Error (GPU extract bits): base log should be <= 32",
|
||||
base_log_bsk <= 32));
|
||||
@@ -142,8 +143,8 @@ void cuda_extract_bits_32(
|
||||
switch (lwe_dimension_in) {
|
||||
case 512:
|
||||
host_extract_bits<uint32_t, Degree<512>>(
|
||||
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
|
||||
(uint32_t *)lwe_array_in_buffer,
|
||||
v_stream, gpu_index, (uint32_t *)list_lwe_array_out,
|
||||
(uint32_t *)lwe_array_in, (uint32_t *)lwe_array_in_buffer,
|
||||
(uint32_t *)lwe_array_in_shifted_buffer,
|
||||
(uint32_t *)lwe_array_out_ks_buffer,
|
||||
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
|
||||
@@ -154,8 +155,8 @@ void cuda_extract_bits_32(
|
||||
break;
|
||||
case 1024:
|
||||
host_extract_bits<uint32_t, Degree<1024>>(
|
||||
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
|
||||
(uint32_t *)lwe_array_in_buffer,
|
||||
v_stream, gpu_index, (uint32_t *)list_lwe_array_out,
|
||||
(uint32_t *)lwe_array_in, (uint32_t *)lwe_array_in_buffer,
|
||||
(uint32_t *)lwe_array_in_shifted_buffer,
|
||||
(uint32_t *)lwe_array_out_ks_buffer,
|
||||
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
|
||||
@@ -166,8 +167,8 @@ void cuda_extract_bits_32(
|
||||
break;
|
||||
case 2048:
|
||||
host_extract_bits<uint32_t, Degree<2048>>(
|
||||
v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
|
||||
(uint32_t *)lwe_array_in_buffer,
|
||||
v_stream, gpu_index, (uint32_t *)list_lwe_array_out,
|
||||
(uint32_t *)lwe_array_in, (uint32_t *)lwe_array_in_buffer,
|
||||
(uint32_t *)lwe_array_in_shifted_buffer,
|
||||
(uint32_t *)lwe_array_out_ks_buffer,
|
||||
(uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
|
||||
@@ -182,13 +183,14 @@ void cuda_extract_bits_32(
|
||||
}
|
||||
|
||||
void cuda_extract_bits_64(
|
||||
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
|
||||
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
|
||||
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
|
||||
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
|
||||
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
void *v_stream, uint32_t gpu_index, void *list_lwe_array_out,
|
||||
void *lwe_array_in, void *lwe_array_in_buffer,
|
||||
void *lwe_array_in_shifted_buffer, void *lwe_array_out_ks_buffer,
|
||||
void *lwe_array_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
|
||||
void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
assert(("Error (GPU extract bits): base log should be <= 64",
|
||||
base_log_bsk <= 64));
|
||||
@@ -213,8 +215,8 @@ void cuda_extract_bits_64(
|
||||
switch (lwe_dimension_in) {
|
||||
case 512:
|
||||
host_extract_bits<uint64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lwe_array_in_buffer,
|
||||
v_stream, gpu_index, (uint64_t *)list_lwe_array_out,
|
||||
(uint64_t *)lwe_array_in, (uint64_t *)lwe_array_in_buffer,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
@@ -225,8 +227,8 @@ void cuda_extract_bits_64(
|
||||
break;
|
||||
case 1024:
|
||||
host_extract_bits<uint64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lwe_array_in_buffer,
|
||||
v_stream, gpu_index, (uint64_t *)list_lwe_array_out,
|
||||
(uint64_t *)lwe_array_in, (uint64_t *)lwe_array_in_buffer,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
@@ -237,8 +239,8 @@ void cuda_extract_bits_64(
|
||||
break;
|
||||
case 2048:
|
||||
host_extract_bits<uint64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
|
||||
(uint64_t *)lwe_array_in_buffer,
|
||||
v_stream, gpu_index, (uint64_t *)list_lwe_array_out,
|
||||
(uint64_t *)lwe_array_in, (uint64_t *)lwe_array_in_buffer,
|
||||
(uint64_t *)lwe_array_in_shifted_buffer,
|
||||
(uint64_t *)lwe_array_out_ks_buffer,
|
||||
(uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
|
||||
@@ -253,39 +255,39 @@ void cuda_extract_bits_64(
|
||||
}
|
||||
|
||||
void cuda_blind_rotate_and_sample_extraction_64(
|
||||
void *v_stream, void *lwe_out, void *ggsw_in, void *lut_vector,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_out, void *ggsw_in,
|
||||
void *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
|
||||
uint32_t max_shared_memory) {
|
||||
|
||||
switch (polynomial_size) {
|
||||
case 512:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<512>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 1024:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<1024>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 2048:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<2048>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 4096:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<4096>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
case 8192:
|
||||
host_blind_rotate_and_sample_extraction<uint64_t, int64_t, Degree<8192>>(
|
||||
v_stream, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
v_stream, gpu_index, (uint64_t *)lwe_out, (uint64_t *)ggsw_in,
|
||||
(uint64_t *)lut_vector, mbr_size, tau, glwe_dimension, polynomial_size,
|
||||
base_log, l_gadget, max_shared_memory);
|
||||
break;
|
||||
|
||||
@@ -246,13 +246,11 @@ device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in,
|
||||
* - r: Number of layers in the tree.
|
||||
*/
|
||||
template <typename Torus, typename STorus, class params>
|
||||
void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
|
||||
Torus *lut_vector, uint32_t glwe_dimension,
|
||||
void host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
|
||||
Torus *ggsw_in, Torus *lut_vector, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t r,
|
||||
uint32_t max_shared_memory) {
|
||||
// This should be refactored to pass the gpu index as a parameter
|
||||
uint32_t gpu_index = 0;
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
int num_lut = (1 << r);
|
||||
@@ -463,14 +461,15 @@ __global__ void add_sub_and_mul_lwe(Torus *shifted_lwe, Torus *state_lwe,
|
||||
|
||||
template <typename Torus, class params>
|
||||
__host__ void host_extract_bits(
|
||||
void *v_stream, Torus *list_lwe_array_out, Torus *lwe_array_in,
|
||||
Torus *lwe_array_in_buffer, Torus *lwe_array_in_shifted_buffer,
|
||||
Torus *lwe_array_out_ks_buffer, Torus *lwe_array_out_pbs_buffer,
|
||||
Torus *lut_pbs, uint32_t *lut_vector_indexes, Torus *ksk,
|
||||
double2 *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_ksk,
|
||||
uint32_t level_count_ksk, uint32_t number_of_samples) {
|
||||
void *v_stream, uint32_t gpu_index, Torus *list_lwe_array_out,
|
||||
Torus *lwe_array_in, Torus *lwe_array_in_buffer,
|
||||
Torus *lwe_array_in_shifted_buffer, Torus *lwe_array_out_ks_buffer,
|
||||
Torus *lwe_array_out_pbs_buffer, Torus *lut_pbs,
|
||||
uint32_t *lut_vector_indexes, Torus *ksk, double2 *fourier_bsk,
|
||||
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log_bsk, uint32_t level_count_bsk,
|
||||
uint32_t base_log_ksk, uint32_t level_count_ksk,
|
||||
uint32_t number_of_samples) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
uint32_t ciphertext_n_bits = sizeof(Torus) * 8;
|
||||
@@ -485,8 +484,9 @@ __host__ void host_extract_bits(
|
||||
|
||||
for (int bit_idx = 0; bit_idx < number_of_bits; bit_idx++) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(
|
||||
v_stream, lwe_array_out_ks_buffer, lwe_array_in_shifted_buffer, ksk,
|
||||
lwe_dimension_in, lwe_dimension_out, base_log_ksk, level_count_ksk, 1);
|
||||
v_stream, gpu_index, lwe_array_out_ks_buffer,
|
||||
lwe_array_in_shifted_buffer, ksk, lwe_dimension_in, lwe_dimension_out,
|
||||
base_log_ksk, level_count_ksk, 1);
|
||||
|
||||
copy_small_lwe<<<1, 256, 0, *stream>>>(
|
||||
list_lwe_array_out, lwe_array_out_ks_buffer, lwe_dimension_out + 1,
|
||||
@@ -508,9 +508,10 @@ __host__ void host_extract_bits(
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
|
||||
host_bootstrap_low_latency<Torus, params>(
|
||||
v_stream, lwe_array_out_pbs_buffer, lut_pbs, lut_vector_indexes,
|
||||
lwe_array_out_ks_buffer, fourier_bsk, lwe_dimension_out,
|
||||
lwe_dimension_in, base_log_bsk, level_count_bsk, number_of_samples, 1);
|
||||
v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_pbs,
|
||||
lut_vector_indexes, lwe_array_out_ks_buffer, fourier_bsk,
|
||||
lwe_dimension_out, lwe_dimension_in, base_log_bsk, level_count_bsk,
|
||||
number_of_samples, 1);
|
||||
|
||||
add_sub_and_mul_lwe<Torus, params><<<1, threads, 0, *stream>>>(
|
||||
lwe_array_in_shifted_buffer, lwe_array_in_buffer,
|
||||
@@ -609,12 +610,10 @@ __global__ void device_blind_rotation_and_sample_extraction(
|
||||
|
||||
template <typename Torus, typename STorus, class params>
|
||||
void host_blind_rotate_and_sample_extraction(
|
||||
void *v_stream, Torus *lwe_out, Torus *ggsw_in, Torus *lut_vector,
|
||||
uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_out, Torus *ggsw_in,
|
||||
Torus *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
|
||||
uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
|
||||
uint32_t max_shared_memory) {
|
||||
// This should be refactored to pass the gpu index as a parameter
|
||||
uint32_t gpu_index = 0;
|
||||
|
||||
assert(glwe_dimension ==
|
||||
1); // For larger k we will need to adjust the mask size
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
* - num_samples blocks of threads are launched
|
||||
*/
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
|
||||
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(
|
||||
v_stream, static_cast<uint32_t *>(lwe_array_out),
|
||||
v_stream, gpu_index, static_cast<uint32_t *>(lwe_array_out),
|
||||
static_cast<uint32_t *>(lwe_array_in), static_cast<uint32_t *>(ksk),
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
|
||||
}
|
||||
@@ -35,11 +35,11 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
|
||||
* - num_samples blocks of threads are launched
|
||||
*/
|
||||
void cuda_keyswitch_lwe_ciphertext_vector_64(
|
||||
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
|
||||
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
|
||||
uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
|
||||
cuda_keyswitch_lwe_ciphertext_vector(
|
||||
v_stream, static_cast<uint64_t *>(lwe_array_out),
|
||||
v_stream, gpu_index, static_cast<uint64_t *>(lwe_array_out),
|
||||
static_cast<uint64_t *>(lwe_array_in), static_cast<uint64_t *>(ksk),
|
||||
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
|
||||
}
|
||||
|
||||
@@ -96,9 +96,10 @@ __global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
|
||||
/// assume lwe_array_in in the gpu
|
||||
template <typename Torus>
|
||||
__host__ void cuda_keyswitch_lwe_ciphertext_vector(
|
||||
void *v_stream, Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk,
|
||||
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
|
||||
uint32_t level_count, uint32_t num_samples) {
|
||||
void *v_stream, uint32_t gpu_index, Torus *lwe_array_out,
|
||||
Torus *lwe_array_in, Torus *ksk, uint32_t lwe_dimension_in,
|
||||
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
|
||||
uint32_t num_samples) {
|
||||
|
||||
constexpr int ideal_threads = 128;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user