feat(cuda): Add a `gpu_index` parameter to the CUDA host functions to specify which GPU should be used.

This commit is contained in:
Pedro Alves
2022-11-24 14:29:03 -03:00
committed by Agnès Leroy
parent f04a29aea4
commit 68866766a4
10 changed files with 140 additions and 141 deletions

View File

@@ -20,69 +20,71 @@ void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream,
uint32_t polynomial_size);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The first parameter line below (without `gpu_index`) appears to be the
// removed text and the second (with `uint32_t gpu_index`) its replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
void cuda_bootstrap_amortized_lwe_ciphertext_vector_32(
// presumably the pre-change first parameter line
void *v_stream, void *lwe_array_out, void *test_vector,
// presumably the post-change line: gpu_index is inserted right after v_stream
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The first parameter line below (without `gpu_index`) appears to be the
// removed text and the second (with `uint32_t gpu_index`) its replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
void cuda_bootstrap_amortized_lwe_ciphertext_vector_64(
// presumably the pre-change first parameter line
void *v_stream, void *lwe_array_out, void *test_vector,
// presumably the post-change line: gpu_index is inserted right after v_stream
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The first parameter line below (without `gpu_index`) appears to be the
// removed text and the second (with `uint32_t gpu_index`) its replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32(
// presumably the pre-change first parameter line
void *v_stream, void *lwe_array_out, void *test_vector,
// presumably the post-change line: gpu_index is inserted right after v_stream
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The first parameter line below (without `gpu_index`) appears to be the
// removed text and the second (with `uint32_t gpu_index`) its replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
void cuda_bootstrap_low_latency_lwe_ciphertext_vector_64(
// presumably the pre-change first parameter line
void *v_stream, void *lwe_array_out, void *test_vector,
// presumably the post-change line: gpu_index is inserted right after v_stream
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector,
void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The declaration opening is shown twice: the first two lines (without
// `gpu_index`) appear to be the removed text, and the next two lines (with
// `uint32_t gpu_index` after `v_stream`) their replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
// presumably the pre-change opening (2 lines)
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
void *lut_vector, uint32_t glwe_dimension,
// presumably the post-change opening (2 lines)
void cuda_cmux_tree_32(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
void *ggsw_in, void *lut_vector, uint32_t glwe_dimension,
// remaining parameters appear once and look unchanged by the commit
uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The declaration opening is shown twice: the first two lines (without
// `gpu_index`) appear to be the removed text, and the next two lines (with
// `uint32_t gpu_index` after `v_stream`) their replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
// presumably the pre-change opening (2 lines)
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
void *lut_vector, uint32_t glwe_dimension,
// presumably the post-change opening (2 lines)
void cuda_cmux_tree_64(void *v_stream, uint32_t gpu_index, void *glwe_array_out,
void *ggsw_in, void *lut_vector, uint32_t glwe_dimension,
// remaining parameters appear once and look unchanged by the commit
uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t r,
uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The first two parameter lines (without `gpu_index`) appear to be the removed
// text, and the next two (with `uint32_t gpu_index` after `v_stream`) their
// replacement. Per the commit message, `gpu_index` is the index of the GPU
// that should be used by this CUDA host function.
void cuda_blind_rotate_and_sample_extraction_64(
// presumably the pre-change parameter lines (2 lines)
void *v_stream, void *lwe_out, void *ggsw_in, void *lut_vector,
uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
// presumably the post-change parameter lines (2 lines)
void *v_stream, uint32_t gpu_index, void *lwe_out, void *ggsw_in,
void *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension,
// remaining parameters appear once and look unchanged by the commit
uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget,
uint32_t max_shared_memory);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The parameter list is shown twice: the first seven lines (without
// `gpu_index`) appear to be the removed text, and the following eight lines
// (with `uint32_t gpu_index` after `v_stream`, re-wrapped) their replacement.
// Apart from `gpu_index`, both copies list the same parameters in the same
// order. Per the commit message, `gpu_index` is the index of the GPU that
// should be used by this CUDA host function.
void cuda_extract_bits_32(
// presumably the pre-change parameter lines (7 lines)
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
// presumably the post-change parameter lines (8 lines)
void *v_stream, uint32_t gpu_index, void *list_lwe_array_out,
void *lwe_array_in, void *lwe_array_in_buffer,
void *lwe_array_in_shifted_buffer, void *lwe_array_out_ks_buffer,
void *lwe_array_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
uint32_t base_log_ksk, uint32_t level_count_ksk,
// final parameter appears once and looks unchanged by the commit
uint32_t number_of_samples);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The parameter list is shown twice: the first seven lines (without
// `gpu_index`) appear to be the removed text, and the following eight lines
// (with `uint32_t gpu_index` after `v_stream`, re-wrapped) their replacement.
// Apart from `gpu_index`, both copies list the same parameters in the same
// order. Per the commit message, `gpu_index` is the index of the GPU that
// should be used by this CUDA host function.
void cuda_extract_bits_64(
// presumably the pre-change parameter lines (7 lines)
void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
// presumably the post-change parameter lines (8 lines)
void *v_stream, uint32_t gpu_index, void *list_lwe_array_out,
void *lwe_array_in, void *lwe_array_in_buffer,
void *lwe_array_in_shifted_buffer, void *lwe_array_out_ks_buffer,
void *lwe_array_out_pbs_buffer, void *lut_pbs, void *lut_vector_indexes,
void *ksk, void *fourier_bsk, uint32_t number_of_bits, uint32_t delta_log,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t glwe_dimension, uint32_t base_log_bsk, uint32_t level_count_bsk,
uint32_t base_log_ksk, uint32_t level_count_ksk,
// final parameter appears once and looks unchanged by the commit
uint32_t number_of_samples);
}

View File

@@ -6,14 +6,14 @@
extern "C" {
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The full parameter list is shown twice: the first three lines (without
// `gpu_index`) appear to be the removed text, and the last three (with
// `uint32_t gpu_index` after `v_stream`, re-wrapped) their replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
void cuda_keyswitch_lwe_ciphertext_vector_32(
// presumably the pre-change parameter list (3 lines)
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples);
// presumably the post-change parameter list (3 lines)
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log, uint32_t level_count, uint32_t num_samples);
// NOTE(review): this looks like a diff rendering whose +/- markers were lost.
// The full parameter list is shown twice: the first three lines (without
// `gpu_index`) appear to be the removed text, and the last three (with
// `uint32_t gpu_index` after `v_stream`, re-wrapped) their replacement.
// Per the commit message, `gpu_index` is the index of the GPU that should be
// used by this CUDA host function.
void cuda_keyswitch_lwe_ciphertext_vector_64(
// presumably the pre-change parameter list (3 lines)
void *v_stream, void *lwe_array_out, void *lwe_array_in, void *ksk,
uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log,
uint32_t level_count, uint32_t num_samples);
// presumably the post-change parameter list (3 lines)
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in,
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
uint32_t base_log, uint32_t level_count, uint32_t num_samples);
}
#endif // CNCRT_KS_H_