Files
concrete/src/bootstrap_wop.cu
Agnes Leroy 4445fcc7f1 chore(cuda): rename some variables to match concrete-core notations
- rename l_gadget and stop calling low lat PBS with N too large
- rename trlwe and trgsw
- rename lwe_mask_size into lwe_dimension
- rename lwe_in into lwe_array_in
- rename lwe_out into lwe_array_out
- rename decomp_level into level
- rename lwe_dimension_before/after into lwe_dimension_in/out
2022-10-19 10:26:08 +02:00

254 lines
11 KiB
Plaintext

#include "bootstrap_wop.cuh"

#include <cstdio>
#include <cstdlib>
/*
 * Runs a CMUX tree on the GPU for 32-bit torus elements.
 *
 * Host-side entry point: validates the parameters, then dispatches to
 * host_cmux_tree<uint32_t, int32_t, Degree<polynomial_size>>. The data
 * pointers (glwe_array_out, ggsw_in, lut_vector) are reinterpreted as
 * uint32_t arrays; v_stream, the scalar parameters and max_shared_memory are
 * forwarded unchanged.
 *
 * Preconditions (checked with assert, i.e. only in builds without NDEBUG):
 *  - base_log <= 16
 *  - polynomial_size is one of 512, 1024, 2048, 4096, 8192
 *  - glwe_dimension == 1
 *  - r, the number of layers in the tree, is >= 1
 *
 * An unsupported polynomial_size terminates the process even when asserts
 * are compiled out, instead of returning without writing glwe_array_out.
 */
void cuda_cmux_tree_32(void *v_stream, void *glwe_array_out, void *ggsw_in,
                       void *lut_vector, uint32_t glwe_dimension,
                       uint32_t polynomial_size, uint32_t base_log,
                       uint32_t level_count, uint32_t r,
                       uint32_t max_shared_memory) {
  assert(("Error (GPU Cmux tree): base log should be <= 16", base_log <= 16));
  assert(("Error (GPU Cmux tree): polynomial size should be one of 512, 1024, "
          "2048, 4096, 8192",
          polynomial_size == 512 || polynomial_size == 1024 ||
              polynomial_size == 2048 || polynomial_size == 4096 ||
              polynomial_size == 8192));
  // For larger k we will need to adjust the mask size
  assert(("Error (GPU Cmux tree): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU Cmux tree): r, the number of layers in the tree, should "
          "be >= 1 ",
          r >= 1));
  switch (polynomial_size) {
  case 512:
    host_cmux_tree<uint32_t, int32_t, Degree<512>>(
        v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
        (uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 1024:
    host_cmux_tree<uint32_t, int32_t, Degree<1024>>(
        v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
        (uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 2048:
    host_cmux_tree<uint32_t, int32_t, Degree<2048>>(
        v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
        (uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 4096:
    host_cmux_tree<uint32_t, int32_t, Degree<4096>>(
        v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
        (uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 8192:
    host_cmux_tree<uint32_t, int32_t, Degree<8192>>(
        v_stream, (uint32_t *)glwe_array_out, (uint32_t *)ggsw_in,
        (uint32_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  default:
    // Only reachable when the polynomial-size assert above is compiled out
    // (NDEBUG). Fail loudly rather than silently skipping the computation.
    std::fprintf(stderr,
                 "Error (GPU Cmux tree): unsupported polynomial size %u\n",
                 static_cast<unsigned>(polynomial_size));
    std::abort();
  }
}
/*
 * Runs a CMUX tree on the GPU for 64-bit torus elements.
 *
 * Host-side entry point: validates the parameters, then dispatches to
 * host_cmux_tree<uint64_t, int64_t, Degree<polynomial_size>>. The data
 * pointers (glwe_array_out, ggsw_in, lut_vector) are reinterpreted as
 * uint64_t arrays; v_stream, the scalar parameters and max_shared_memory are
 * forwarded unchanged.
 *
 * Preconditions (checked with assert, i.e. only in builds without NDEBUG):
 *  - base_log <= 16
 *  - polynomial_size is one of 512, 1024, 2048, 4096, 8192
 *  - glwe_dimension == 1
 *  - r, the number of layers in the tree, is >= 1
 *
 * An unsupported polynomial_size terminates the process even when asserts
 * are compiled out, instead of returning without writing glwe_array_out.
 */
void cuda_cmux_tree_64(void *v_stream, void *glwe_array_out, void *ggsw_in,
                       void *lut_vector, uint32_t glwe_dimension,
                       uint32_t polynomial_size, uint32_t base_log,
                       uint32_t level_count, uint32_t r,
                       uint32_t max_shared_memory) {
  assert(("Error (GPU Cmux tree): base log should be <= 16", base_log <= 16));
  assert(("Error (GPU Cmux tree): polynomial size should be one of 512, 1024, "
          "2048, 4096, 8192",
          polynomial_size == 512 || polynomial_size == 1024 ||
              polynomial_size == 2048 || polynomial_size == 4096 ||
              polynomial_size == 8192));
  // For larger k we will need to adjust the mask size
  assert(("Error (GPU Cmux tree): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU Cmux tree): r, the number of layers in the tree, should "
          "be >= 1 ",
          r >= 1));
  switch (polynomial_size) {
  case 512:
    host_cmux_tree<uint64_t, int64_t, Degree<512>>(
        v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
        (uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 1024:
    host_cmux_tree<uint64_t, int64_t, Degree<1024>>(
        v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
        (uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 2048:
    host_cmux_tree<uint64_t, int64_t, Degree<2048>>(
        v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
        (uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 4096:
    host_cmux_tree<uint64_t, int64_t, Degree<4096>>(
        v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
        (uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  case 8192:
    host_cmux_tree<uint64_t, int64_t, Degree<8192>>(
        v_stream, (uint64_t *)glwe_array_out, (uint64_t *)ggsw_in,
        (uint64_t *)lut_vector, glwe_dimension, polynomial_size, base_log,
        level_count, r, max_shared_memory);
    break;
  default:
    // Only reachable when the polynomial-size assert above is compiled out
    // (NDEBUG). Fail loudly rather than silently skipping the computation.
    std::fprintf(stderr,
                 "Error (GPU Cmux tree): unsupported polynomial size %u\n",
                 static_cast<unsigned>(polynomial_size));
    std::abort();
  }
}
/*
 * Bit extraction on the GPU, 32-bit variant.
 *
 * Host-side entry point: validates the parameters, then dispatches to
 * host_extract_bits<uint32_t, Degree<lwe_dimension_in>>.
 *
 * Argument forwarding:
 *  - The ciphertext, buffer, LUT and ksk pointers are reinterpreted as
 *    uint32_t arrays.
 *  - lut_vector_indexes is reinterpreted as a uint32_t array.
 *  - fourier_bsk is reinterpreted as a double2 array.
 *  - v_stream and the scalar parameters are forwarded unchanged.
 *
 * lwe_dimension_in selects the Degree<> template argument, i.e. it is used
 * as the polynomial size. NOTE(review): this presumably relies on
 * glwe_dimension == 1 so that the input LWE dimension equals the polynomial
 * size; confirm.
 *
 * Preconditions (checked with assert, i.e. only in builds without NDEBUG):
 *  - base_log_bsk <= 16
 *  - glwe_dimension == 1
 *  - lwe_dimension_in is one of 512, 1024, 2048
 *  - number_of_samples <= (SM count of the current device) /
 *    (8 * level_count_bsk)
 *
 * An unsupported lwe_dimension_in terminates the process even when asserts
 * are compiled out, instead of returning without doing any work.
 */
void cuda_extract_bits_32(
    void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
    void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
    void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
    void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
    uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
    uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
    uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
    uint32_t number_of_samples) {
  assert(("Error (GPU extract bits): base log should be <= 16",
          base_log_bsk <= 16));
  assert(("Error (GPU extract bits): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
          "512, 1024, 2048",
          lwe_dimension_in == 512 || lwe_dimension_in == 1024 ||
              lwe_dimension_in == 2048));
  // The number of samples should be lower than the number of streaming
  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
  // to the occupancy of 50%). The only supported value for k is 1, so
  // k + 1 = 2 for now.
  // Query the SM count of the caller's current device rather than hard-coding
  // device 0. Otherwise the bound is checked against the wrong GPU whenever
  // the current device is not device 0.
  int gpu_index = 0;
  cudaGetDevice(&gpu_index);
  int number_of_sm = 0;
  cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount,
                         gpu_index);
  assert(("Error (GPU extract bits): the number of input LWEs must be lower or "
          "equal to the "
          "number of streaming multiprocessors on the device divided by 8 * "
          "level_count_bsk",
          number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
  switch (lwe_dimension_in) {
  case 512:
    host_extract_bits<uint32_t, Degree<512>>(
        v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
        (uint32_t *)lwe_array_in_buffer,
        (uint32_t *)lwe_array_in_shifted_buffer,
        (uint32_t *)lwe_array_out_ks_buffer,
        (uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
        (uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
        number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
        base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
        number_of_samples);
    break;
  case 1024:
    host_extract_bits<uint32_t, Degree<1024>>(
        v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
        (uint32_t *)lwe_array_in_buffer,
        (uint32_t *)lwe_array_in_shifted_buffer,
        (uint32_t *)lwe_array_out_ks_buffer,
        (uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
        (uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
        number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
        base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
        number_of_samples);
    break;
  case 2048:
    host_extract_bits<uint32_t, Degree<2048>>(
        v_stream, (uint32_t *)list_lwe_array_out, (uint32_t *)lwe_array_in,
        (uint32_t *)lwe_array_in_buffer,
        (uint32_t *)lwe_array_in_shifted_buffer,
        (uint32_t *)lwe_array_out_ks_buffer,
        (uint32_t *)lwe_array_out_pbs_buffer, (uint32_t *)lut_pbs,
        (uint32_t *)lut_vector_indexes, (uint32_t *)ksk, (double2 *)fourier_bsk,
        number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
        base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
        number_of_samples);
    break;
  default:
    // Only reachable when the lwe_dimension_in assert above is compiled out
    // (NDEBUG). Fail loudly rather than silently skipping the computation.
    std::fprintf(stderr,
                 "Error (GPU extract bits): unsupported lwe_dimension_in %u\n",
                 static_cast<unsigned>(lwe_dimension_in));
    std::abort();
  }
}
/*
 * Bit extraction on the GPU, 64-bit variant.
 *
 * Host-side entry point: validates the parameters, then dispatches to
 * host_extract_bits<uint64_t, Degree<lwe_dimension_in>>.
 *
 * Argument forwarding:
 *  - The ciphertext, buffer, LUT and ksk pointers are reinterpreted as
 *    uint64_t arrays.
 *  - lut_vector_indexes is reinterpreted as a uint32_t array.
 *  - fourier_bsk is reinterpreted as a double2 array.
 *  - v_stream and the scalar parameters are forwarded unchanged.
 *
 * lwe_dimension_in selects the Degree<> template argument, i.e. it is used
 * as the polynomial size. NOTE(review): this presumably relies on
 * glwe_dimension == 1 so that the input LWE dimension equals the polynomial
 * size; confirm.
 *
 * Preconditions (checked with assert, i.e. only in builds without NDEBUG):
 *  - base_log_bsk <= 16
 *  - glwe_dimension == 1
 *  - lwe_dimension_in is one of 512, 1024, 2048
 *  - number_of_samples <= (SM count of the current device) /
 *    (8 * level_count_bsk)
 *
 * An unsupported lwe_dimension_in terminates the process even when asserts
 * are compiled out, instead of returning without doing any work.
 */
void cuda_extract_bits_64(
    void *v_stream, void *list_lwe_array_out, void *lwe_array_in,
    void *lwe_array_in_buffer, void *lwe_array_in_shifted_buffer,
    void *lwe_array_out_ks_buffer, void *lwe_array_out_pbs_buffer,
    void *lut_pbs, void *lut_vector_indexes, void *ksk, void *fourier_bsk,
    uint32_t number_of_bits, uint32_t delta_log, uint32_t lwe_dimension_in,
    uint32_t lwe_dimension_out, uint32_t glwe_dimension, uint32_t base_log_bsk,
    uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk,
    uint32_t number_of_samples) {
  assert(("Error (GPU extract bits): base log should be <= 16",
          base_log_bsk <= 16));
  assert(("Error (GPU extract bits): glwe_dimension should be equal to 1",
          glwe_dimension == 1));
  assert(("Error (GPU extract bits): lwe_dimension_in should be one of "
          "512, 1024, 2048",
          lwe_dimension_in == 512 || lwe_dimension_in == 1024 ||
              lwe_dimension_in == 2048));
  // The number of samples should be lower than the number of streaming
  // multiprocessors divided by (4 * (k + 1) * l) (the factor 4 being related
  // to the occupancy of 50%). The only supported value for k is 1, so
  // k + 1 = 2 for now.
  // Query the SM count of the caller's current device rather than hard-coding
  // device 0. Otherwise the bound is checked against the wrong GPU whenever
  // the current device is not device 0.
  int gpu_index = 0;
  cudaGetDevice(&gpu_index);
  int number_of_sm = 0;
  cudaDeviceGetAttribute(&number_of_sm, cudaDevAttrMultiProcessorCount,
                         gpu_index);
  assert(("Error (GPU extract bits): the number of input LWEs must be lower or "
          "equal to the "
          "number of streaming multiprocessors on the device divided by 8 * "
          "level_count_bsk",
          number_of_samples <= number_of_sm / 4. / 2. / level_count_bsk));
  switch (lwe_dimension_in) {
  case 512:
    host_extract_bits<uint64_t, Degree<512>>(
        v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
        (uint64_t *)lwe_array_in_buffer,
        (uint64_t *)lwe_array_in_shifted_buffer,
        (uint64_t *)lwe_array_out_ks_buffer,
        (uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
        (uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
        number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
        base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
        number_of_samples);
    break;
  case 1024:
    host_extract_bits<uint64_t, Degree<1024>>(
        v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
        (uint64_t *)lwe_array_in_buffer,
        (uint64_t *)lwe_array_in_shifted_buffer,
        (uint64_t *)lwe_array_out_ks_buffer,
        (uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
        (uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
        number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
        base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
        number_of_samples);
    break;
  case 2048:
    host_extract_bits<uint64_t, Degree<2048>>(
        v_stream, (uint64_t *)list_lwe_array_out, (uint64_t *)lwe_array_in,
        (uint64_t *)lwe_array_in_buffer,
        (uint64_t *)lwe_array_in_shifted_buffer,
        (uint64_t *)lwe_array_out_ks_buffer,
        (uint64_t *)lwe_array_out_pbs_buffer, (uint64_t *)lut_pbs,
        (uint32_t *)lut_vector_indexes, (uint64_t *)ksk, (double2 *)fourier_bsk,
        number_of_bits, delta_log, lwe_dimension_in, lwe_dimension_out,
        base_log_bsk, level_count_bsk, base_log_ksk, level_count_ksk,
        number_of_samples);
    break;
  default:
    // Only reachable when the lwe_dimension_in assert above is compiled out
    // (NDEBUG). Fail loudly rather than silently skipping the computation.
    std::fprintf(stderr,
                 "Error (GPU extract bits): unsupported lwe_dimension_in %u\n",
                 static_cast<unsigned>(lwe_dimension_in));
    std::abort();
  }
}