From c4f0daa203793127cbb5681c4b10a3b8aba236b4 Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Wed, 4 Jan 2023 16:19:26 -0300 Subject: [PATCH] fix(cuda): Fix the CUDA test for CBS+VP when tau > 1 and tau * p == log2(N). --- src/vertical_packing.cuh | 4 +-- src/wop_bootstrap.cuh | 63 +++++++++++++++------------------------- 2 files changed, 26 insertions(+), 41 deletions(-) diff --git a/src/vertical_packing.cuh b/src/vertical_packing.cuh index 4607698e9..9daeb58ac 100644 --- a/src/vertical_packing.cuh +++ b/src/vertical_packing.cuh @@ -276,8 +276,8 @@ void host_cmux_tree(void *v_stream, uint32_t gpu_index, Torus *glwe_array_out, int num_lut = (1 << r); if (r == 0) { // Simply copy the LUTs - add_padding_to_lut_async( - glwe_array_out, lut_vector, glwe_dimension, num_lut * tau, stream); + add_padding_to_lut_async(glwe_array_out, lut_vector, + glwe_dimension, tau, stream); checkCudaErrors(cudaStreamSynchronize(*stream)); return; } diff --git a/src/wop_bootstrap.cuh b/src/wop_bootstrap.cuh index fcde83007..361b95c30 100644 --- a/src/wop_bootstrap.cuh +++ b/src/wop_bootstrap.cuh @@ -37,7 +37,7 @@ __host__ void host_circuit_bootstrap_vertical_packing( Torus *cbs_fpksk, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_pksk, uint32_t level_count_pksk, uint32_t base_log_cbs, - uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t lut_number, + uint32_t level_count_cbs, uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory) { auto stream = static_cast(v_stream); @@ -100,46 +100,31 @@ __host__ void host_circuit_bootstrap_vertical_packing( free(h_lut_vector_indexes); // number_of_inputs = tau * p is the total number of GGSWs - if (number_of_inputs > params::log2_degree) { - // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the - // lsb GGSW is for the last blind rotation. - uint32_t r = number_of_inputs - params::log2_degree; - Torus *br_ggsw = (Torus *)ggsw_out + - (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) * - (glwe_dimension + 1) * polynomial_size); - Torus *glwe_array_out = (Torus *)cuda_malloc_async( - lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus), - stream, gpu_index); - // CMUX Tree - // r = tau * p - log2(N) - host_cmux_tree( - v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector, - glwe_dimension, polynomial_size, base_log_cbs, level_count_cbs, r, - lut_number, max_shared_memory); - checkCudaErrors(cudaGetLastError()); - cuda_drop_async(glwe_array_out, stream, gpu_index); + // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the + // lsb GGSW is for the last blind rotation. + uint32_t r = number_of_inputs - params::log2_degree; + Torus *glwe_array_out = (Torus *)cuda_malloc_async( + tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus), stream, + gpu_index); + // CMUX Tree + // r = tau * p - log2(N) + host_cmux_tree( + v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector, glwe_dimension, + polynomial_size, base_log_cbs, level_count_cbs, r, tau, + max_shared_memory); + checkCudaErrors(cudaGetLastError()); - // Blind rotation + sample extraction - // mbr = tau * p - r = log2(N) - host_blind_rotate_and_sample_extraction( - v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out, - number_of_inputs - r, lut_number, glwe_dimension, polynomial_size, - base_log_cbs, level_count_cbs, max_shared_memory); - } else { - // we need to expand the lut to fill the masks with zeros - Torus *lut_vector_glwe = (Torus *)cuda_malloc_async( - lut_number * (glwe_dimension + 1) * polynomial_size * sizeof(Torus), - stream, gpu_index); - add_padding_to_lut_async(lut_vector_glwe, lut_vector, - glwe_dimension, lut_number, stream); - checkCudaErrors(cudaGetLastError()); + // Blind rotation + sample extraction + // mbr = tau * p - r = log2(N) + Torus *br_ggsw = (Torus *)ggsw_out + + (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) * + (glwe_dimension + 1) * polynomial_size); + host_blind_rotate_and_sample_extraction( + v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out, + number_of_inputs - r, tau, glwe_dimension, polynomial_size, base_log_cbs, + level_count_cbs, max_shared_memory); - // Blind rotation + sample extraction - host_blind_rotate_and_sample_extraction( - v_stream, gpu_index, lwe_array_out, ggsw_out, lut_vector_glwe, - number_of_inputs, lut_number, glwe_dimension, polynomial_size, - base_log_cbs, level_count_cbs, max_shared_memory); - } + cuda_drop_async(glwe_array_out, stream, gpu_index); cuda_drop_async(ggsw_out, stream, gpu_index); }