From 5cd0cb5d19aae9d65717f282f62607094412b57f Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 16 Feb 2023 10:04:50 +0100 Subject: [PATCH] refactor(cuda): introduce scratch for amortized PBS --- include/bootstrap.h | 61 +++++----- include/circuit_bootstrap.h | 42 +++++++ src/CMakeLists.txt | 3 +- src/boolean_gates.cu | 54 +++++++-- src/bootstrap_amortized.cu | 148 +++++++++++++++++++++---- src/bootstrap_amortized.cuh | 122 ++++++++++++++------ src/circuit_bootstrap.cu | 215 ++++++++++++++++++++++++------------ src/circuit_bootstrap.cuh | 86 ++++++++++++--- src/vertical_packing.cuh | 1 - src/wop_bootstrap.cuh | 127 +++++++++++---------- 10 files changed, 612 insertions(+), 247 deletions(-) create mode 100644 include/circuit_bootstrap.h diff --git a/include/bootstrap.h b/include/bootstrap.h index 732fdccd4..6ac6d8088 100644 --- a/include/bootstrap.h +++ b/include/bootstrap.h @@ -17,19 +17,40 @@ void cuda_convert_lwe_bootstrap_key_64(void *dest, void *src, void *v_stream, uint32_t glwe_dim, uint32_t level_count, uint32_t polynomial_size); +void scratch_cuda_bootstrap_amortized_32(void *v_stream, uint32_t gpu_index, + int8_t **pbs_buffer, + uint32_t glwe_dimension, + uint32_t polynomial_size, + uint32_t input_lwe_ciphertext_count, + uint32_t max_shared_memory, + bool allocate_gpu_memory); + +void scratch_cuda_bootstrap_amortized_64(void *v_stream, uint32_t gpu_index, + int8_t **pbs_buffer, + uint32_t glwe_dimension, + uint32_t polynomial_size, + uint32_t input_lwe_ciphertext_count, + uint32_t max_shared_memory, + bool allocate_gpu_memory); + void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( - void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector, - void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t base_log, uint32_t level_count, uint32_t num_samples, - uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory); + 
void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector, + void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key, + int8_t *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, + uint32_t num_samples, uint32_t num_lut_vectors, uint32_t lwe_idx, + uint32_t max_shared_memory); void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( - void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector, - void *test_vector_indexes, void *lwe_array_in, void *bootstrapping_key, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t base_log, uint32_t level_count, uint32_t num_samples, - uint32_t num_test_vectors, uint32_t lwe_idx, uint32_t max_shared_memory); + void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector, + void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key, + int8_t *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, + uint32_t num_samples, uint32_t num_lut_vectors, uint32_t lwe_idx, + uint32_t max_shared_memory); + +void cleanup_cuda_bootstrap_amortized(void *v_stream, uint32_t gpu_index, + int8_t **pbs_buffer); void cuda_bootstrap_low_latency_lwe_ciphertext_vector_32( void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *test_vector, @@ -67,26 +88,6 @@ void cuda_extract_bits_64( uint32_t base_log_ksk, uint32_t level_count_ksk, uint32_t number_of_samples, uint32_t max_shared_memory); -void cuda_circuit_bootstrap_32( - void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in, - void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer, - void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer, - void *lwe_array_in_fp_ks_buffer, uint32_t delta_log, - uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, - uint32_t 
level_bsk, uint32_t base_log_bsk, uint32_t level_pksk, - uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs, - uint32_t number_of_samples, uint32_t max_shared_memory); - -void cuda_circuit_bootstrap_64( - void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in, - void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer, - void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer, - void *lwe_array_in_fp_ks_buffer, uint32_t delta_log, - uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, - uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk, - uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs, - uint32_t number_of_samples, uint32_t max_shared_memory); - void scratch_cuda_circuit_bootstrap_vertical_packing_32( void *v_stream, uint32_t gpu_index, int8_t **cbs_vp_buffer, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, diff --git a/include/circuit_bootstrap.h b/include/circuit_bootstrap.h new file mode 100644 index 000000000..728810344 --- /dev/null +++ b/include/circuit_bootstrap.h @@ -0,0 +1,42 @@ +#ifndef CUDA_CIRCUIT_BOOTSTRAP_H +#define CUDA_CIRCUIT_BOOTSTRAP_H + +#include + +extern "C" { + +void scratch_cuda_circuit_bootstrap_32( + void *v_stream, uint32_t gpu_index, int8_t **cbs_buffer, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, + uint32_t level_count_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory, bool allocate_gpu_memory); + +void scratch_cuda_circuit_bootstrap_64( + void *v_stream, uint32_t gpu_index, int8_t **cbs_buffer, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, + uint32_t level_count_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory, bool allocate_gpu_memory); + +void cuda_circuit_bootstrap_32( + void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in, + void *fourier_bsk, void *fp_ksk_array, void 
*lut_vector_indexes, + int8_t *cbs_buffer, uint32_t delta_log, uint32_t polynomial_size, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_bsk, + uint32_t base_log_bsk, uint32_t level_pksk, uint32_t base_log_pksk, + uint32_t level_cbs, uint32_t base_log_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory); + +void cuda_circuit_bootstrap_64( + void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in, + void *fourier_bsk, void *fp_ksk_array, void *lut_vector_indexes, + int8_t *cbs_buffer, uint32_t delta_log, uint32_t polynomial_size, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_bsk, + uint32_t base_log_bsk, uint32_t level_pksk, uint32_t base_log_pksk, + uint32_t level_cbs, uint32_t base_log_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory); + +void cleanup_cuda_circuit_bootstrap(void *v_stream, uint32_t gpu_index, + int8_t **cbs_buffer); +} + +#endif // CUDA_CIRCUIT_BOOTSTRAP_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 95be8d92f..282249b23 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,7 +3,8 @@ set(SOURCES ${CMAKE_SOURCE_DIR}/${INCLUDE_DIR}/bootstrap.h ${CMAKE_SOURCE_DIR}/${INCLUDE_DIR}/keyswitch.h ${CMAKE_SOURCE_DIR}/${INCLUDE_DIR}/linear_algebra.h - ${CMAKE_SOURCE_DIR}/${INCLUDE_DIR}/vertical_packing.h) + ${CMAKE_SOURCE_DIR}/${INCLUDE_DIR}/vertical_packing.h + ${CMAKE_SOURCE_DIR}/${INCLUDE_DIR}/circuit_bootstrap.h) file(GLOB SOURCES "*.cu" "*.h" diff --git a/src/boolean_gates.cu b/src/boolean_gates.cu index d906065f3..55f2d0e7c 100644 --- a/src/boolean_gates.cu +++ b/src/boolean_gates.cu @@ -93,11 +93,16 @@ extern "C" void cuda_boolean_and_32( stream, gpu_index); check_cuda_error(cudaGetLastError()); + int8_t *pbs_buffer = nullptr; + scratch_cuda_bootstrap_amortized_32( + v_stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, true); cuda_bootstrap_amortized_lwe_ciphertext_vector_32( v_stream, 
gpu_index, lwe_pbs_buffer, pbs_lut, pbs_lut_indexes, - lwe_buffer_2, bootstrapping_key, input_lwe_dimension, glwe_dimension, - polynomial_size, pbs_base_log, pbs_level_count, + lwe_buffer_2, bootstrapping_key, pbs_buffer, input_lwe_dimension, + glwe_dimension, polynomial_size, pbs_base_log, pbs_level_count, input_lwe_ciphertext_count, 1, 0, max_shared_memory); + cleanup_cuda_bootstrap_amortized(v_stream, gpu_index, &pbs_buffer); check_cuda_error(cudaGetLastError()); cuda_drop_async(lwe_buffer_2, stream, gpu_index); @@ -196,11 +201,16 @@ extern "C" void cuda_boolean_nand_32( stream, gpu_index); check_cuda_error(cudaGetLastError()); + int8_t *pbs_buffer = nullptr; + scratch_cuda_bootstrap_amortized_32( + v_stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, true); cuda_bootstrap_amortized_lwe_ciphertext_vector_32( v_stream, gpu_index, lwe_pbs_buffer, pbs_lut, pbs_lut_indexes, - lwe_buffer_3, bootstrapping_key, input_lwe_dimension, glwe_dimension, - polynomial_size, pbs_base_log, pbs_level_count, + lwe_buffer_3, bootstrapping_key, pbs_buffer, input_lwe_dimension, + glwe_dimension, polynomial_size, pbs_base_log, pbs_level_count, input_lwe_ciphertext_count, 1, 0, max_shared_memory); + cleanup_cuda_bootstrap_amortized(v_stream, gpu_index, &pbs_buffer); check_cuda_error(cudaGetLastError()); cuda_drop_async(lwe_buffer_3, stream, gpu_index); @@ -299,11 +309,16 @@ extern "C" void cuda_boolean_nor_32( stream, gpu_index); check_cuda_error(cudaGetLastError()); + int8_t *pbs_buffer = nullptr; + scratch_cuda_bootstrap_amortized_32( + v_stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, true); cuda_bootstrap_amortized_lwe_ciphertext_vector_32( v_stream, gpu_index, lwe_pbs_buffer, pbs_lut, pbs_lut_indexes, - lwe_buffer_3, bootstrapping_key, input_lwe_dimension, glwe_dimension, - polynomial_size, pbs_base_log, pbs_level_count, + lwe_buffer_3, 
bootstrapping_key, pbs_buffer, input_lwe_dimension, + glwe_dimension, polynomial_size, pbs_base_log, pbs_level_count, input_lwe_ciphertext_count, 1, 0, max_shared_memory); + cleanup_cuda_bootstrap_amortized(v_stream, gpu_index, &pbs_buffer); check_cuda_error(cudaGetLastError()); cuda_drop_async(lwe_buffer_3, stream, gpu_index); @@ -394,11 +409,16 @@ extern "C" void cuda_boolean_or_32( stream, gpu_index); check_cuda_error(cudaGetLastError()); + int8_t *pbs_buffer = nullptr; + scratch_cuda_bootstrap_amortized_32( + v_stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, true); cuda_bootstrap_amortized_lwe_ciphertext_vector_32( v_stream, gpu_index, lwe_pbs_buffer, pbs_lut, pbs_lut_indexes, - lwe_buffer_2, bootstrapping_key, input_lwe_dimension, glwe_dimension, - polynomial_size, pbs_base_log, pbs_level_count, + lwe_buffer_2, bootstrapping_key, pbs_buffer, input_lwe_dimension, + glwe_dimension, polynomial_size, pbs_base_log, pbs_level_count, input_lwe_ciphertext_count, 1, 0, max_shared_memory); + cleanup_cuda_bootstrap_amortized(v_stream, gpu_index, &pbs_buffer); check_cuda_error(cudaGetLastError()); cuda_drop_async(lwe_buffer_2, stream, gpu_index); @@ -510,11 +530,16 @@ extern "C" void cuda_boolean_xor_32( stream, gpu_index); check_cuda_error(cudaGetLastError()); + int8_t *pbs_buffer = nullptr; + scratch_cuda_bootstrap_amortized_32( + v_stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, true); cuda_bootstrap_amortized_lwe_ciphertext_vector_32( v_stream, gpu_index, lwe_pbs_buffer, pbs_lut, pbs_lut_indexes, - lwe_buffer_3, bootstrapping_key, input_lwe_dimension, glwe_dimension, - polynomial_size, pbs_base_log, pbs_level_count, + lwe_buffer_3, bootstrapping_key, pbs_buffer, input_lwe_dimension, + glwe_dimension, polynomial_size, pbs_base_log, pbs_level_count, input_lwe_ciphertext_count, 1, 0, max_shared_memory); + 
cleanup_cuda_bootstrap_amortized(v_stream, gpu_index, &pbs_buffer); check_cuda_error(cudaGetLastError()); cuda_drop_async(lwe_buffer_3, stream, gpu_index); @@ -633,11 +658,16 @@ extern "C" void cuda_boolean_xnor_32( stream, gpu_index); check_cuda_error(cudaGetLastError()); + int8_t *pbs_buffer = nullptr; + scratch_cuda_bootstrap_amortized_32( + v_stream, gpu_index, &pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, true); cuda_bootstrap_amortized_lwe_ciphertext_vector_32( v_stream, gpu_index, lwe_pbs_buffer, pbs_lut, pbs_lut_indexes, - lwe_buffer_4, bootstrapping_key, input_lwe_dimension, glwe_dimension, - polynomial_size, pbs_base_log, pbs_level_count, + lwe_buffer_4, bootstrapping_key, pbs_buffer, input_lwe_dimension, + glwe_dimension, polynomial_size, pbs_base_log, pbs_level_count, input_lwe_ciphertext_count, 1, 0, max_shared_memory); + cleanup_cuda_bootstrap_amortized(v_stream, gpu_index, &pbs_buffer); check_cuda_error(cudaGetLastError()); cuda_drop_async(lwe_buffer_4, stream, gpu_index); diff --git a/src/bootstrap_amortized.cu b/src/bootstrap_amortized.cu index 79da3e8f9..069fcfddb 100644 --- a/src/bootstrap_amortized.cu +++ b/src/bootstrap_amortized.cu @@ -1,15 +1,113 @@ #include "bootstrap_amortized.cuh" +/* + * This scratch function allocates the necessary amount of data on the GPU for + * the amortized PBS on 32 bits inputs, into `pbs_buffer`. It also + * configures SM options on the GPU in case FULLSM mode is going to be used. 
+ */ +void scratch_cuda_bootstrap_amortized_32(void *v_stream, uint32_t gpu_index, + int8_t **pbs_buffer, + uint32_t glwe_dimension, + uint32_t polynomial_size, + uint32_t input_lwe_ciphertext_count, + uint32_t max_shared_memory, + bool allocate_gpu_memory) { + + switch (polynomial_size) { + case 256: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 512: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 1024: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 2048: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 4096: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 8192: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + default: + break; + } +} + +/* + * This scratch function allocates the necessary amount of data on the GPU for + * the amortized PBS on 64 bits inputs, into `pbs_buffer`. It also + * configures SM options on the GPU in case FULLSM mode is going to be used. 
+ */ +void scratch_cuda_bootstrap_amortized_64(void *v_stream, uint32_t gpu_index, + int8_t **pbs_buffer, + uint32_t glwe_dimension, + uint32_t polynomial_size, + uint32_t input_lwe_ciphertext_count, + uint32_t max_shared_memory, + bool allocate_gpu_memory) { + + switch (polynomial_size) { + case 256: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 512: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 1024: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 2048: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 4096: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + case 8192: + scratch_bootstrap_amortized>( + v_stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size, + input_lwe_ciphertext_count, max_shared_memory, allocate_gpu_memory); + break; + default: + break; + } +} + /* Perform the programmable bootstrapping on a batch of input u32 LWE * ciphertexts. See the corresponding operation on 64 bits for more details. 
*/ - void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector, void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t base_log, uint32_t level_count, uint32_t num_samples, - uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) { + int8_t *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, + uint32_t num_samples, uint32_t num_lut_vectors, uint32_t lwe_idx, + uint32_t max_shared_memory) { assert( ("Error (GPU amortized PBS): base log should be <= 32", base_log <= 32)); @@ -25,7 +123,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( host_bootstrap_amortized>( v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector, (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -33,7 +131,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( host_bootstrap_amortized>( v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector, (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -41,7 +139,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( host_bootstrap_amortized>( v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector, (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + 
(double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -49,7 +147,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( host_bootstrap_amortized>( v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector, (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -57,7 +155,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( host_bootstrap_amortized>( v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector, (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -65,7 +163,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( host_bootstrap_amortized>( v_stream, gpu_index, (uint32_t *)lwe_array_out, (uint32_t *)lut_vector, (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -142,9 +240,10 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_32( void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lut_vector, void *lut_vector_indexes, void *lwe_array_in, void *bootstrapping_key, - uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, - uint32_t base_log, uint32_t 
level_count, uint32_t num_samples, - uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) { + int8_t *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, + uint32_t num_samples, uint32_t num_lut_vectors, uint32_t lwe_idx, + uint32_t max_shared_memory) { assert( ("Error (GPU amortized PBS): base log should be <= 64", base_log <= 64)); @@ -160,7 +259,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( host_bootstrap_amortized>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector, (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -168,7 +267,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( host_bootstrap_amortized>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector, (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -176,7 +275,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( host_bootstrap_amortized>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector, (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -184,7 +283,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( host_bootstrap_amortized>( v_stream, gpu_index, (uint64_t 
*)lwe_array_out, (uint64_t *)lut_vector, (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -192,7 +291,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( host_bootstrap_amortized>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector, (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -200,7 +299,7 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( host_bootstrap_amortized>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lut_vector, (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_in, - (double2 *)bootstrapping_key, glwe_dimension, lwe_dimension, + (double2 *)bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, num_samples, num_lut_vectors, lwe_idx, max_shared_memory); break; @@ -208,3 +307,14 @@ void cuda_bootstrap_amortized_lwe_ciphertext_vector_64( break; } } + +/* + * This cleanup function frees the data for the amortized PBS on GPU in + * pbs_buffer for 32 or 64 bits inputs. 
+ */ +void cleanup_cuda_bootstrap_amortized(void *v_stream, uint32_t gpu_index, + int8_t **pbs_buffer) { + auto stream = static_cast(v_stream); + // Free memory + cuda_drop_async(*pbs_buffer, stream, gpu_index); +} diff --git a/src/bootstrap_amortized.cuh b/src/bootstrap_amortized.cuh index 81ace4848..1d6872bf1 100644 --- a/src/bootstrap_amortized.cuh +++ b/src/bootstrap_amortized.cuh @@ -213,27 +213,100 @@ __global__ void device_bootstrap_amortized( glwe_dimension); } +template +__host__ __device__ int +get_buffer_size_full_sm_bootstrap_amortized(uint32_t polynomial_size, + uint32_t glwe_dimension) { + return sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator + sizeof(Torus) * polynomial_size * + (glwe_dimension + 1) + // accumulator rotated + sizeof(double2) * polynomial_size / 2 * + (glwe_dimension + 1) + // accumulator fft + sizeof(double2) * polynomial_size / 2 * + (glwe_dimension + 1); // calculate buffer fft +} + +template +__host__ __device__ int +get_buffer_size_partial_sm_bootstrap_amortized(uint32_t polynomial_size, + uint32_t glwe_dimension) { + return sizeof(double2) * polynomial_size / 2 * + (glwe_dimension + 1); // calculate buffer fft +} + +template +__host__ __device__ int get_buffer_size_bootstrap_amortized( + uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory) { + + int full_sm = get_buffer_size_full_sm_bootstrap_amortized( + polynomial_size, glwe_dimension); + int partial_sm = get_buffer_size_partial_sm_bootstrap_amortized( + polynomial_size, glwe_dimension); + int partial_dm = full_sm - partial_sm; + int full_dm = full_sm; + int device_mem = 0; + if (max_shared_memory < partial_sm) { + device_mem = full_dm * input_lwe_ciphertext_count; + } else if (max_shared_memory < full_sm) { + device_mem = partial_dm * input_lwe_ciphertext_count; + } + return device_mem; +} + +template +__host__ void scratch_bootstrap_amortized(void *v_stream, uint32_t gpu_index, + int8_t 
**pbs_buffer, + uint32_t glwe_dimension, + uint32_t polynomial_size, + uint32_t input_lwe_ciphertext_count, + uint32_t max_shared_memory, + bool allocate_gpu_memory) { + cudaSetDevice(gpu_index); + auto stream = static_cast(v_stream); + + int full_sm = get_buffer_size_full_sm_bootstrap_amortized( + polynomial_size, glwe_dimension); + int partial_sm = get_buffer_size_partial_sm_bootstrap_amortized( + polynomial_size, glwe_dimension); + if (max_shared_memory >= partial_sm && max_shared_memory < full_sm) { + cudaFuncSetAttribute(device_bootstrap_amortized, + cudaFuncAttributeMaxDynamicSharedMemorySize, + partial_sm); + cudaFuncSetCacheConfig(device_bootstrap_amortized, + cudaFuncCachePreferShared); + } else if (max_shared_memory >= partial_sm) { + check_cuda_error(cudaFuncSetAttribute( + device_bootstrap_amortized, + cudaFuncAttributeMaxDynamicSharedMemorySize, full_sm)); + check_cuda_error(cudaFuncSetCacheConfig( + device_bootstrap_amortized, + cudaFuncCachePreferShared)); + } + if (allocate_gpu_memory) { + int buffer_size = get_buffer_size_bootstrap_amortized( + glwe_dimension, polynomial_size, input_lwe_ciphertext_count, + max_shared_memory); + *pbs_buffer = (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index); + check_cuda_error(cudaGetLastError()); + } +} + template __host__ void host_bootstrap_amortized( void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lut_vector, Torus *lut_vector_indexes, Torus *lwe_array_in, double2 *bootstrapping_key, - uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, - uint32_t base_log, uint32_t level_count, + int8_t *pbs_buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, + uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t input_lwe_ciphertext_count, uint32_t num_lut_vectors, uint32_t lwe_idx, uint32_t max_shared_memory) { cudaSetDevice(gpu_index); - int SM_FULL = - sizeof(Torus) * polynomial_size * (glwe_dimension + 1) + // accumulator - sizeof(Torus) * 
polynomial_size * - (glwe_dimension + 1) + // accumulator rotated - sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1) + // accumulator fft - sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1); // calculate buffer fft + int SM_FULL = get_buffer_size_full_sm_bootstrap_amortized( + polynomial_size, glwe_dimension); - int SM_PART = sizeof(double2) * polynomial_size / 2 * - (glwe_dimension + 1); // calculate buffer fft + int SM_PART = get_buffer_size_partial_sm_bootstrap_amortized( + polynomial_size, glwe_dimension); int DM_PART = SM_FULL - SM_PART; @@ -241,8 +314,6 @@ __host__ void host_bootstrap_amortized( auto stream = static_cast(v_stream); - int8_t *d_mem; - // Create a 1-dimensional grid of threads // where each block handles 1 sample and each thread // handles opt polynomial coefficients @@ -257,23 +328,15 @@ __host__ void host_bootstrap_amortized( // from one of three templates (no use, partial use or full use // of shared memory) if (max_shared_memory < SM_PART) { - d_mem = (int8_t *)cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count, - stream, gpu_index); device_bootstrap_amortized<<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, - bootstrapping_key, d_mem, glwe_dimension, lwe_dimension, + bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, lwe_idx, DM_FULL); } else if (max_shared_memory < SM_FULL) { - cudaFuncSetAttribute(device_bootstrap_amortized, - cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART); - cudaFuncSetCacheConfig(device_bootstrap_amortized, - cudaFuncCachePreferShared); - d_mem = (int8_t *)cuda_malloc_async(DM_PART * input_lwe_ciphertext_count, - stream, gpu_index); device_bootstrap_amortized <<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, - bootstrapping_key, d_mem, glwe_dimension, lwe_dimension, + bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, lwe_idx, DM_PART); } else { // 
For devices with compute capability 7.x a single thread block can @@ -281,22 +344,13 @@ __host__ void host_bootstrap_amortized( // device then has to be allocated dynamically. // For lower compute capabilities, this call // just does nothing and the amount of shared memory used is 48 KB - check_cuda_error(cudaFuncSetAttribute( - device_bootstrap_amortized, - cudaFuncAttributeMaxDynamicSharedMemorySize, SM_FULL)); - check_cuda_error(cudaFuncSetCacheConfig( - device_bootstrap_amortized, - cudaFuncCachePreferShared)); - d_mem = (int8_t *)cuda_malloc_async(0, stream, gpu_index); - device_bootstrap_amortized <<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, - bootstrapping_key, d_mem, glwe_dimension, lwe_dimension, + bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log, level_count, lwe_idx, 0); } check_cuda_error(cudaGetLastError()); - cuda_drop_async(d_mem, stream, gpu_index); } template diff --git a/src/circuit_bootstrap.cu b/src/circuit_bootstrap.cu index 4c7ac4953..e806cdb0f 100644 --- a/src/circuit_bootstrap.cu +++ b/src/circuit_bootstrap.cu @@ -1,4 +1,99 @@ #include "circuit_bootstrap.cuh" +#include "circuit_bootstrap.h" + +/* + * This scratch function allocates the necessary amount of data on the GPU for + * the circuit bootstrap on 32 bits inputs, into `cbs_buffer`. It also + * configures SM options on the GPU in case FULLSM mode is going to be used. 
+ */ +void scratch_cuda_circuit_bootstrap_32( + void *v_stream, uint32_t gpu_index, int8_t **cbs_buffer, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, + uint32_t level_count_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory, bool allocate_gpu_memory) { + + switch (polynomial_size) { + case 512: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 1024: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 2048: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 4096: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 8192: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + default: + break; + } +} + +/* + * This scratch function allocates the necessary amount of data on the GPU for + * the circuit bootstrap on 64 bits inputs, into `cbs_buffer`. It also + * configures SM options on the GPU in case FULLSM mode is going to be used. 
+ */ +void scratch_cuda_circuit_bootstrap_64( + void *v_stream, uint32_t gpu_index, int8_t **cbs_buffer, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, + uint32_t level_count_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory, bool allocate_gpu_memory) { + + switch (polynomial_size) { + case 512: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 1024: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 2048: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 4096: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + case 8192: + scratch_circuit_bootstrap>( + v_stream, gpu_index, cbs_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + allocate_gpu_memory); + break; + default: + break; + } +} /* * Perform circuit bootstrapping for the batch of 32 bit LWE ciphertexts. 
@@ -6,13 +101,12 @@ */ void cuda_circuit_bootstrap_32( void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in, - void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer, - void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer, - void *lwe_array_in_fp_ks_buffer, uint32_t delta_log, - uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, - uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk, - uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs, - uint32_t number_of_samples, uint32_t max_shared_memory) { + void *fourier_bsk, void *fp_ksk_array, void *lut_vector_indexes, + int8_t *cbs_buffer, uint32_t delta_log, uint32_t polynomial_size, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_bsk, + uint32_t base_log_bsk, uint32_t level_pksk, uint32_t base_log_pksk, + uint32_t level_cbs, uint32_t base_log_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory) { assert(("Error (GPU circuit bootstrap): polynomial_size should be one of " "512, 1024, 2048, 4096, 8192", polynomial_size == 512 || polynomial_size == 1024 || @@ -28,61 +122,51 @@ void cuda_circuit_bootstrap_32( "equal to the " "number of streaming multiprocessors on the device divided by 8 * " "level_count_bsk", - number_of_samples <= number_of_sm / 4. / 2. / level_bsk)); + number_of_inputs <= number_of_sm / 4. / 2. 
/ level_bsk)); switch (polynomial_size) { case 512: host_circuit_bootstrap>( v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in, (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array, - (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector, - (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer, - (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint32_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 1024: host_circuit_bootstrap>( v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in, (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array, - (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector, - (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer, - (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint32_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 2048: host_circuit_bootstrap>( v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in, (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array, - (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector, - (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer, - (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint32_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; 
case 4096: host_circuit_bootstrap>( v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in, (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array, - (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector, - (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer, - (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint32_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 8192: host_circuit_bootstrap>( v_stream, gpu_index, (uint32_t *)ggsw_out, (uint32_t *)lwe_array_in, (double2 *)fourier_bsk, (uint32_t *)fp_ksk_array, - (uint32_t *)lwe_array_in_shifted_buffer, (uint32_t *)lut_vector, - (uint32_t *)lut_vector_indexes, (uint32_t *)lwe_array_out_pbs_buffer, - (uint32_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint32_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; default: @@ -96,31 +180,20 @@ void cuda_circuit_bootstrap_32( * launch * - `gpu_index` is the index of the GPU to be used in the kernel launch * - 'ggsw_out' output batch of ggsw with size: - * 'number_of_samples' * 'level_cbs' * ('glwe_dimension' + 1)^2 * + * 'number_of_inputs' * 'level_cbs' * ('glwe_dimension' + 1)^2 * * polynomial_size * sizeof(u64) * - 'lwe_array_in' input batch of lwe ciphertexts, with size: - * 'number_of_samples' * '(lwe_dimension' + 1) * sizeof(u64) + * 'number_of_inputs' * '(lwe_dimension' + 1) * sizeof(u64) * - 'fourier_bsk' bootstrapping key in fourier domain with size: * 'lwe_dimension' * 'level_bsk' * ('glwe_dimension' + 1)^2 * * 'polynomial_size' / 2 * 
sizeof(double2) * - 'fp_ksk_array' batch of fp-keyswitch keys with size: * ('polynomial_size' + 1) * 'level_pksk' * ('glwe_dimension' + 1)^2 * * 'polynomial_size' * sizeof(u64) - * The following 5 parameters are used during calculations, they are not actual - * inputs of the function, they are just allocated memory for calculation + * - 'cbs_buffer': buffer used during calculations, it is not an actual + * inputs of the function, just allocated memory for calculation * process, like this, memory can be allocated once and can be used as much * as needed for different calls of circuit_bootstrap function - * - 'lwe_array_in_shifted_buffer' with size: - * 'number_of_samples' * 'level_cbs' * ('lwe_dimension' + 1) * sizeof(u64) - * - 'lut_vector' with size: - * 'level_cbs' * ('glwe_dimension' + 1) * 'polynomial_size' * sizeof(u64) - * - 'lut_vector_indexes' stores the index corresponding to which test - * vector to use - * - 'lwe_array_out_pbs_buffer' with size - * 'number_of_samples' * 'level_cbs' * ('polynomial_size' + 1) * sizeof(u64) - * - 'lwe_array_in_fp_ks_buffer' with size - * 'number_of_samples' * 'level_cbs' * ('glwe_dimension' + 1) * - * ('polynomial_size' + 1) * sizeof(u64) * * This function calls a wrapper to a device kernel that performs the * circuit bootstrap. 
The kernel is templatized based on integer discretization @@ -128,13 +201,12 @@ void cuda_circuit_bootstrap_32( */ void cuda_circuit_bootstrap_64( void *v_stream, uint32_t gpu_index, void *ggsw_out, void *lwe_array_in, - void *fourier_bsk, void *fp_ksk_array, void *lwe_array_in_shifted_buffer, - void *lut_vector, void *lut_vector_indexes, void *lwe_array_out_pbs_buffer, - void *lwe_array_in_fp_ks_buffer, uint32_t delta_log, - uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, - uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk, - uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs, - uint32_t number_of_samples, uint32_t max_shared_memory) { + void *fourier_bsk, void *fp_ksk_array, void *lut_vector_indexes, + int8_t *cbs_buffer, uint32_t delta_log, uint32_t polynomial_size, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_bsk, + uint32_t base_log_bsk, uint32_t level_pksk, uint32_t base_log_pksk, + uint32_t level_cbs, uint32_t base_log_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory) { assert(("Error (GPU circuit bootstrap): polynomial_size should be one of " "512, 1024, 2048, 4096, 8192", polynomial_size == 512 || polynomial_size == 1024 || @@ -150,65 +222,66 @@ void cuda_circuit_bootstrap_64( "equal to the " "number of streaming multiprocessors on the device divided by 8 * " "level_count_bsk", - number_of_samples <= number_of_sm / 4. / 2. / level_bsk)); + number_of_inputs <= number_of_sm / 4. / 2. 
/ level_bsk)); // The number of samples should be lower than the number of streaming switch (polynomial_size) { case 512: host_circuit_bootstrap>( v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in, (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array, - (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector, - (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer, - (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint64_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 1024: host_circuit_bootstrap>( v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in, (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array, - (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector, - (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer, - (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint64_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 2048: host_circuit_bootstrap>( v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in, (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array, - (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector, - (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer, - (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint64_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, 
level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 4096: host_circuit_bootstrap>( v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in, (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array, - (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector, - (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer, - (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint64_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; case 8192: host_circuit_bootstrap>( v_stream, gpu_index, (uint64_t *)ggsw_out, (uint64_t *)lwe_array_in, (double2 *)fourier_bsk, (uint64_t *)fp_ksk_array, - (uint64_t *)lwe_array_in_shifted_buffer, (uint64_t *)lut_vector, - (uint64_t *)lut_vector_indexes, (uint64_t *)lwe_array_out_pbs_buffer, - (uint64_t *)lwe_array_in_fp_ks_buffer, delta_log, polynomial_size, + (uint64_t *)lut_vector_indexes, cbs_buffer, delta_log, polynomial_size, glwe_dimension, lwe_dimension, level_bsk, base_log_bsk, level_pksk, - base_log_pksk, level_cbs, base_log_cbs, number_of_samples, + base_log_pksk, level_cbs, base_log_cbs, number_of_inputs, max_shared_memory); break; default: break; } } + +/* + * This cleanup function frees the data for the circuit bootstrap on GPU in + * cbs_buffer for 32 or 64 bits inputs. 
+ */ +void cleanup_cuda_circuit_bootstrap(void *v_stream, uint32_t gpu_index, + int8_t **cbs_buffer) { + auto stream = static_cast(v_stream); + // Free memory + cuda_drop_async(*cbs_buffer, stream, gpu_index); +} diff --git a/src/circuit_bootstrap.cuh b/src/circuit_bootstrap.cuh index 3bcd38b84..39daaba16 100644 --- a/src/circuit_bootstrap.cuh +++ b/src/circuit_bootstrap.cuh @@ -1,5 +1,5 @@ -#ifndef CBS_H -#define CBS_H +#ifndef CBS_CUH +#define CBS_CUH #include "bit_extraction.cuh" #include "bootstrap.h" @@ -96,6 +96,51 @@ __global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src, } } +template +__host__ __device__ int +get_buffer_size_cbs(uint32_t glwe_dimension, uint32_t lwe_dimension, + uint32_t polynomial_size, uint32_t level_count_cbs, + uint32_t number_of_inputs) { + + return number_of_inputs * level_count_cbs * (glwe_dimension + 1) * + (polynomial_size + 1) * + sizeof(Torus) + // lwe_array_in_fp_ks_buffer + number_of_inputs * level_count_cbs * (polynomial_size + 1) * + sizeof(Torus) + // lwe_array_out_pbs_buffer + number_of_inputs * level_count_cbs * (lwe_dimension + 1) * + sizeof(Torus) + // lwe_array_in_shifted_buffer + level_count_cbs * (glwe_dimension + 1) * polynomial_size * + sizeof(Torus); // lut_vector_cbs +} + +template +__host__ void scratch_circuit_bootstrap( + void *v_stream, uint32_t gpu_index, int8_t **cbs_buffer, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, + uint32_t level_count_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory, bool allocate_gpu_memory) { + + cudaSetDevice(gpu_index); + auto stream = static_cast(v_stream); + + int pbs_count = number_of_inputs * level_count_cbs; + // allocate and initialize device pointers for circuit bootstrap and vertical + // packing + if (allocate_gpu_memory) { + int buffer_size = + get_buffer_size_cbs(glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, + number_of_inputs) + + get_buffer_size_bootstrap_amortized( + glwe_dimension, 
polynomial_size, pbs_count, max_shared_memory); + *cbs_buffer = (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index); + } + + scratch_bootstrap_amortized( + v_stream, gpu_index, cbs_buffer, glwe_dimension, polynomial_size, + pbs_count, max_shared_memory, false); +} + /* * Host function for cuda circuit bootstrap. * It executes device functions in specific order and manages @@ -104,24 +149,37 @@ __global__ void copy_add_lwe_cbs(Torus *lwe_dst, Torus *lwe_src, template __host__ void host_circuit_bootstrap( void *v_stream, uint32_t gpu_index, Torus *ggsw_out, Torus *lwe_array_in, - double2 *fourier_bsk, Torus *fp_ksk_array, - Torus *lwe_array_in_shifted_buffer, Torus *lut_vector, - Torus *lut_vector_indexes, Torus *lwe_array_out_pbs_buffer, - Torus *lwe_array_in_fp_ks_buffer, uint32_t delta_log, - uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, - uint32_t level_bsk, uint32_t base_log_bsk, uint32_t level_pksk, - uint32_t base_log_pksk, uint32_t level_cbs, uint32_t base_log_cbs, - uint32_t number_of_samples, uint32_t max_shared_memory) { + double2 *fourier_bsk, Torus *fp_ksk_array, Torus *lut_vector_indexes, + int8_t *cbs_buffer, uint32_t delta_log, uint32_t polynomial_size, + uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_bsk, + uint32_t base_log_bsk, uint32_t level_pksk, uint32_t base_log_pksk, + uint32_t level_cbs, uint32_t base_log_cbs, uint32_t number_of_inputs, + uint32_t max_shared_memory) { cudaSetDevice(gpu_index); auto stream = static_cast(v_stream); uint32_t ciphertext_n_bits = sizeof(Torus) * 8; uint32_t lwe_size = lwe_dimension + 1; - int pbs_count = number_of_samples * level_cbs; + int pbs_count = number_of_inputs * level_cbs; - dim3 blocks(level_cbs, number_of_samples, 1); + dim3 blocks(level_cbs, number_of_inputs, 1); int threads = 256; + Torus *lwe_array_in_fp_ks_buffer = (Torus *)cbs_buffer; + Torus *lwe_array_out_pbs_buffer = + (Torus *)lwe_array_in_fp_ks_buffer + + (ptrdiff_t)(number_of_inputs * 
level_cbs * (glwe_dimension + 1) * + (polynomial_size + 1)); + Torus *lwe_array_in_shifted_buffer = + (Torus *)lwe_array_out_pbs_buffer + + (ptrdiff_t)(number_of_inputs * level_cbs * (polynomial_size + 1)); + Torus *lut_vector = + (Torus *)lwe_array_in_shifted_buffer + + (ptrdiff_t)(number_of_inputs * level_cbs * (lwe_dimension + 1)); + int8_t *pbs_buffer = + (int8_t *)lut_vector + (ptrdiff_t)(level_cbs * (glwe_dimension + 1) * + polynomial_size * sizeof(Torus)); + // Shift message LSB on padding bit, at this point we expect to have messages // with only 1 bit of information shift_lwe_cbs<<>>( @@ -143,7 +201,7 @@ __host__ void host_circuit_bootstrap( // MSB and no bit of padding host_bootstrap_amortized( v_stream, gpu_index, lwe_array_out_pbs_buffer, lut_vector, - lut_vector_indexes, lwe_array_in_shifted_buffer, fourier_bsk, + lut_vector_indexes, lwe_array_in_shifted_buffer, fourier_bsk, pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, base_log_bsk, level_bsk, pbs_count, level_cbs, 0, max_shared_memory); @@ -161,4 +219,4 @@ __host__ void host_circuit_bootstrap( level_pksk, pbs_count * (glwe_dimension + 1), glwe_dimension + 1); } -#endif // CBS_H +#endif // CBS_CUH diff --git a/src/vertical_packing.cuh b/src/vertical_packing.cuh index a8b6ab4fb..96a83c147 100644 --- a/src/vertical_packing.cuh +++ b/src/vertical_packing.cuh @@ -1,7 +1,6 @@ #ifndef VERTICAL_PACKING_CUH #define VERTICAL_PACKING_CUH -#include "../include/vertical_packing.h" #include "bootstrap.h" #include "complex/operations.cuh" #include "crypto/gadget.cuh" diff --git a/src/wop_bootstrap.cuh b/src/wop_bootstrap.cuh index 6985ce42e..cae28ab5d 100644 --- a/src/wop_bootstrap.cuh +++ b/src/wop_bootstrap.cuh @@ -28,26 +28,17 @@ __global__ void device_build_lut(Torus *lut_out, Torus *lut_in, template __host__ __device__ int -get_buffer_size_cbs_vp(uint32_t glwe_dimension, uint32_t lwe_dimension, - uint32_t polynomial_size, uint32_t level_count_cbs, - uint32_t number_of_inputs, uint32_t tau) { 
+get_buffer_size_cbs_vp(uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t level_count_cbs, uint32_t tau, + uint32_t number_of_inputs) { int ggsw_size = level_count_cbs * (glwe_dimension + 1) * (glwe_dimension + 1) * polynomial_size; return number_of_inputs * level_count_cbs * - sizeof(Torus) // lut_vector_indexes - + number_of_inputs * ggsw_size * sizeof(Torus) // ggsw_out - + - number_of_inputs * level_count_cbs * (glwe_dimension + 1) * - (polynomial_size + 1) * sizeof(Torus) // lwe_array_in_fp_ks_buffer - + number_of_inputs * level_count_cbs * (polynomial_size + 1) * - sizeof(Torus) // lwe_array_out_pbs_buffer - + number_of_inputs * level_count_cbs * (lwe_dimension + 1) * - sizeof(Torus) // lwe_array_in_shifted_buffer - + level_count_cbs * (glwe_dimension + 1) * polynomial_size * - sizeof(Torus) // lut_vector_cbs - + tau * (glwe_dimension + 1) * polynomial_size * - sizeof(Torus); // glwe_array_out + sizeof(Torus) + // lut_vector_indexes + number_of_inputs * ggsw_size * sizeof(Torus) + // ggsw_out_cbs + tau * (glwe_dimension + 1) * polynomial_size * + sizeof(Torus); // glwe_array_out_cmux_tree } template @@ -70,9 +61,14 @@ __host__ void scratch_circuit_bootstrap_vertical_packing( // packing if (allocate_gpu_memory) { int buffer_size = - get_buffer_size_cbs_vp(glwe_dimension, lwe_dimension, - polynomial_size, level_count_cbs, - number_of_inputs, tau) + + get_buffer_size_cbs_vp(glwe_dimension, polynomial_size, + level_count_cbs, tau, number_of_inputs) + + get_buffer_size_cbs(glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, + number_of_inputs) + + get_buffer_size_bootstrap_amortized( + glwe_dimension, polynomial_size, number_of_inputs * level_count_cbs, + max_shared_memory) + get_buffer_size_cmux_tree(glwe_dimension, polynomial_size, level_count_cbs, r, tau, max_shared_memory) + @@ -96,6 +92,10 @@ __host__ void scratch_circuit_bootstrap_vertical_packing( uint32_t bits = sizeof(Torus) * 8; *cbs_delta_log = (bits - 1); + 
scratch_circuit_bootstrap( + v_stream, gpu_index, cbs_vp_buffer, glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, number_of_inputs, max_shared_memory, + false); scratch_cmux_tree( v_stream, gpu_index, cbs_vp_buffer, glwe_dimension, polynomial_size, level_count_cbs, r, tau, max_shared_memory, false); @@ -107,13 +107,8 @@ __host__ void scratch_circuit_bootstrap_vertical_packing( /* * Cleanup functions free the necessary data on the GPU and on the CPU. * Data that lives on the CPU is prefixed with `h_`. This cleanup function thus - * frees the data for the circuit bootstrap and vertical packing on GPU: - * - ggsw_out - * - lwe_array_in_fp_ks_buffer - * - lwe_array_out_pbs_buffer - * - lwe_array_in_shifted buffer - * - lut_vector_cbs - * - lut_vector_indexes + * frees the data for the circuit bootstrap and vertical packing on GPU + * contained in cbs_vp_buffer */ __host__ void cleanup_circuit_bootstrap_vertical_packing(void *v_stream, uint32_t gpu_index, @@ -139,32 +134,25 @@ __host__ void host_circuit_bootstrap_vertical_packing( int ggsw_size = level_count_cbs * (glwe_dimension + 1) * (glwe_dimension + 1) * polynomial_size; - Torus *lut_vector_indexes = (Torus *)cbs_vp_buffer; - Torus *ggsw_out = (Torus *)lut_vector_indexes + - (ptrdiff_t)(number_of_inputs * level_count_cbs); - Torus *lwe_array_in_fp_ks_buffer = - (Torus *)ggsw_out + (ptrdiff_t)(number_of_inputs * ggsw_size); - Torus *lwe_array_out_pbs_buffer = - (Torus *)lwe_array_in_fp_ks_buffer + - (ptrdiff_t)(number_of_inputs * level_count_cbs * (glwe_dimension + 1) * - (polynomial_size + 1)); - Torus *lwe_array_in_shifted_buffer = - (Torus *)lwe_array_out_pbs_buffer + - (ptrdiff_t)(number_of_inputs * level_count_cbs * (polynomial_size + 1)); - Torus *lut_vector_cbs = - (Torus *)lwe_array_in_shifted_buffer + - (ptrdiff_t)(number_of_inputs * level_count_cbs * (lwe_dimension + 1)); - Torus *glwe_array_out = - (Torus *)lut_vector_cbs + - (ptrdiff_t)(level_count_cbs * (glwe_dimension + 1) * 
polynomial_size); + Torus *lut_vector_indexes = (Torus *)cbs_vp_buffer; + int8_t *cbs_buffer = + (int8_t *)lut_vector_indexes + + (ptrdiff_t)(number_of_inputs * level_count_cbs * sizeof(Torus)); + int8_t *ggsw_out_cbs = + cbs_buffer + + (ptrdiff_t)(get_buffer_size_cbs(glwe_dimension, lwe_dimension, + polynomial_size, level_count_cbs, + number_of_inputs) + + get_buffer_size_bootstrap_amortized( + glwe_dimension, polynomial_size, + number_of_inputs * level_count_cbs, max_shared_memory)); host_circuit_bootstrap( - v_stream, gpu_index, ggsw_out, lwe_array_in, fourier_bsk, cbs_fpksk, - lwe_array_in_shifted_buffer, lut_vector_cbs, lut_vector_indexes, - lwe_array_out_pbs_buffer, lwe_array_in_fp_ks_buffer, cbs_delta_log, - polynomial_size, glwe_dimension, lwe_dimension, level_count_bsk, - base_log_bsk, level_count_pksk, base_log_pksk, level_count_cbs, - base_log_cbs, number_of_inputs, max_shared_memory); + v_stream, gpu_index, (Torus *)ggsw_out_cbs, lwe_array_in, fourier_bsk, + cbs_fpksk, lut_vector_indexes, cbs_buffer, cbs_delta_log, polynomial_size, + glwe_dimension, lwe_dimension, level_count_bsk, base_log_bsk, + level_count_pksk, base_log_pksk, level_count_cbs, base_log_cbs, + number_of_inputs, max_shared_memory); check_cuda_error(cudaGetLastError()); // number_of_inputs = tau * p is the total number of GGSWs @@ -172,29 +160,33 @@ __host__ void host_circuit_bootstrap_vertical_packing( // lsb GGSW is for the last blind rotation. 
uint32_t r = number_of_inputs - params::log2_degree; int8_t *cmux_tree_buffer = - (int8_t *)glwe_array_out + - (ptrdiff_t)(tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus)); + ggsw_out_cbs + (ptrdiff_t)(number_of_inputs * ggsw_size * sizeof(Torus)); + int8_t *glwe_array_out_cmux_tree = + cmux_tree_buffer + (ptrdiff_t)(get_buffer_size_cmux_tree( + glwe_dimension, polynomial_size, level_count_cbs, + r, tau, max_shared_memory)); // CMUX Tree // r = tau * p - log2(N) host_cmux_tree( - v_stream, gpu_index, glwe_array_out, ggsw_out, lut_vector, - cmux_tree_buffer, glwe_dimension, polynomial_size, base_log_cbs, - level_count_cbs, r, tau, max_shared_memory); + v_stream, gpu_index, (Torus *)glwe_array_out_cmux_tree, + (Torus *)ggsw_out_cbs, lut_vector, cmux_tree_buffer, glwe_dimension, + polynomial_size, base_log_cbs, level_count_cbs, r, tau, + max_shared_memory); check_cuda_error(cudaGetLastError()); // Blind rotation + sample extraction // mbr = tau * p - r = log2(N) - Torus *br_ggsw = (Torus *)ggsw_out + + Torus *br_ggsw = (Torus *)ggsw_out_cbs + (ptrdiff_t)(r * level_count_cbs * (glwe_dimension + 1) * (glwe_dimension + 1) * polynomial_size); int8_t *br_se_buffer = - cmux_tree_buffer + (ptrdiff_t)(get_buffer_size_cmux_tree( - glwe_dimension, polynomial_size, level_count_cbs, - r, tau, max_shared_memory)); + glwe_array_out_cmux_tree + + (ptrdiff_t)(tau * (glwe_dimension + 1) * polynomial_size * sizeof(Torus)); host_blind_rotate_and_sample_extraction( - v_stream, gpu_index, lwe_array_out, br_ggsw, glwe_array_out, br_se_buffer, - number_of_inputs - r, tau, glwe_dimension, polynomial_size, base_log_cbs, - level_count_cbs, max_shared_memory); + v_stream, gpu_index, lwe_array_out, br_ggsw, + (Torus *)glwe_array_out_cmux_tree, br_se_buffer, number_of_inputs - r, + tau, glwe_dimension, polynomial_size, base_log_cbs, level_count_cbs, + max_shared_memory); } template @@ -238,9 +230,14 @@ scratch_wop_pbs(void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer, 
uint32_t r = cbs_vp_number_of_inputs - params::log2_degree; uint32_t mbr_size = cbs_vp_number_of_inputs - r; int buffer_size = - get_buffer_size_cbs_vp(glwe_dimension, lwe_dimension, - polynomial_size, level_count_cbs, - cbs_vp_number_of_inputs, tau) + + get_buffer_size_cbs_vp(glwe_dimension, polynomial_size, + level_count_cbs, tau, + cbs_vp_number_of_inputs) + + get_buffer_size_cbs(glwe_dimension, lwe_dimension, polynomial_size, + level_count_cbs, cbs_vp_number_of_inputs) + + get_buffer_size_bootstrap_amortized( + glwe_dimension, polynomial_size, + cbs_vp_number_of_inputs * level_count_cbs, max_shared_memory) + get_buffer_size_cmux_tree(glwe_dimension, polynomial_size, level_count_cbs, r, tau, max_shared_memory) +