fix(cuda): Fix the CUDA test for CBS+VP when tau > 1 and tau * p == log2(N).

Author: Pedro Alves
Date: 2023-01-04 16:19:26 -03:00
Committed by: Agnès Leroy
Parent: e82a8d4e81
Commit: c4f0daa203
2 changed files with 26 additions and 41 deletions


@@ -276,8 +276,8 @@ void host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out,
   int num_lut = (1 << r);
   if (r == 0) {
     // Simply copy the LUTs
-    add_padding_to_lut_async<Torus, params>(
-        glwe_array_out, lut_vector, glwe_dimension, num_lut * tau, stream);
+    add_padding_to_lut_async<Torus, params>(glwe_array_out, lut_vector,
+                                            glwe_dimension, tau, stream);
     checkCudaErrors(cudaStreamSynchronize(*stream));
     return;
   }


@@ -37,7 +37,7 @@ __host__ void host_circuit_bootstrap_vertical_packing(
     Torus *cbs_fpksk, uint32_t glwe_dimension, uint32_t lwe_dimension,
     uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk,
     uint32_t base_log_pksk, uint32_t level_count_pksk, uint32_t base_log_cbs,
-    uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t lut_number,
+    uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t tau,
     uint32_t max_shared_memory) {
   auto stream = static_cast<cudaStream_t *>(v_stream);
@@ -100,46 +100,31 @@ __host__ void host_circuit_bootstrap_vertical_packing(
   free(h_lut_vector_indexes);
   // number_of_inputs = tau * p is the total number of GGSWs
-  if (number_of_inputs > params::log2_degree) {
-    // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the
-    // lsb GGSW is for the last blind rotation.
-    uint32_t r = number_of_inputs - params::log2_degree;
-    Torus *br_ggsw = (Torus *)ggsw_out +
-                     (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
-                                 (glwe_dimension + 1) * polynomial_size);
-    Torus *glwe_array_out = (Torus *)cuda_malloc_async(
-        lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
-        stream, gpu_index);
-    // CMUX Tree
-    // r = tau * p - log2(N)
-    host_cmux_tree<Torus, STorus, params>(
-        v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector,
-        glwe_dimension, polynomial_size, base_log_cbs, level_count_cbs, r,
-        lut_number, max_shared_memory);
-    checkCudaErrors(cudaGetLastError());
-    cuda_drop_async(glwe_array_out, stream, gpu_index);
+  // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the
+  // lsb GGSW is for the last blind rotation.
+  uint32_t r = number_of_inputs - params::log2_degree;
+  Torus *glwe_array_out = (Torus *)cuda_malloc_async(
+      tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus), stream,
+      gpu_index);
+  // CMUX Tree
+  // r = tau * p - log2(N)
+  host_cmux_tree<Torus, STorus, params>(
+      v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector, glwe_dimension,
+      polynomial_size, base_log_cbs, level_count_cbs, r, tau,
+      max_shared_memory);
+  checkCudaErrors(cudaGetLastError());
-    // Blind rotation + sample extraction
-    // mbr = tau * p - r = log2(N)
-    host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
-        v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out,
-        number_of_inputs - r, lut_number, glwe_dimension, polynomial_size,
-        base_log_cbs, level_count_cbs, max_shared_memory);
-  } else {
-    // we need to expand the lut to fill the masks with zeros
-    Torus *lut_vector_glwe = (Torus *)cuda_malloc_async(
-        lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus),
-        stream, gpu_index);
-    add_padding_to_lut_async<Torus, params>(lut_vector_glwe, lut_vector,
-                                            glwe_dimension, lut_number, stream);
-    checkCudaErrors(cudaGetLastError());
+  // Blind rotation + sample extraction
+  // mbr = tau * p - r = log2(N)
+  Torus *br_ggsw = (Torus *)ggsw_out +
+                   (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) *
+                               (glwe_dimension + 1) * polynomial_size);
+  host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
+      v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out,
+      number_of_inputs - r, tau, glwe_dimension, polynomial_size, base_log_cbs,
+      level_count_cbs, max_shared_memory);
-    // Blind rotation + sample extraction
-    host_blind_rotate_and_sample_extraction<Torus, STorus, params>(
-        v_stream, gpu_index, lwe_array_out, ggsw_out, lut_vector_glwe,
-        number_of_inputs, lut_number, glwe_dimension, polynomial_size,
-        base_log_cbs, level_count_cbs, max_shared_memory);
-  }
+  cuda_drop_async(glwe_array_out, stream, gpu_index);
   cuda_drop_async(ggsw_out, stream, gpu_index);
 }
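
For context, a minimal standalone sketch of the parameter arithmetic the comments in the diff refer to. The concrete values (N = 1024, tau = 2, p = 5) are illustrative assumptions, not taken from the actual test: they simply satisfy tau > 1 and tau * p == log2(N), the case named in the commit title, where r == 0 and the CMux tree degenerates to a copy of the tau LUTs while all GGSWs feed the final blind rotation.

// Illustrative sketch only (plain C++, hypothetical parameter values).
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example: N = 1024 => log2(N) = 10, tau = 2 output ciphertexts,
  // p = 5 extracted bits per output, so tau * p == log2(N).
  const uint32_t log2_degree = 10;
  const uint32_t tau = 2;
  const uint32_t p = 5;
  const uint32_t number_of_inputs = tau * p; // total number of GGSWs

  // r GGSWs are consumed by the CMux tree, the remaining log2(N) by the
  // last blind rotation (r = tau * p - log2(N), mbr = log2(N)).
  const uint32_t r = number_of_inputs - log2_degree; // here r == 0
  const uint32_t num_lut = 1u << r; // LUTs handled by the CMux tree

  // With r == 0 the CMux tree only copies the tau LUTs (num_lut == 1), and
  // the blind rotation consumes all number_of_inputs - r == log2(N) GGSWs.
  printf("r = %u, num_lut = %u, GGSWs for blind rotation = %u\n", r, num_lut,
         number_of_inputs - r);
  return 0;
}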