From e6dfb588db95abe179ec89e4de6579e90cb248b2 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Wed, 15 Feb 2023 09:22:07 +0100 Subject: [PATCH] refactor(cuda): prepare to introduce cmux tree scratch --- include/bootstrap.h | 40 ++-- src/bootstrap_amortized.cuh | 18 +- src/bootstrap_low_latency.cuh | 22 +-- src/crypto/ggsw.cuh | 8 +- src/keyswitch.cuh | 4 +- src/polynomial/polynomial.cuh | 4 +- src/vertical_packing.cuh | 48 ++--- src/wop_bootstrap.cu | 336 +++++++++++++++++++++++----------- src/wop_bootstrap.cuh | 72 ++++---- 9 files changed, 333 insertions(+), 219 deletions(-) diff --git a/include/bootstrap.h b/include/bootstrap.h index f56169036..e157100ec 100644 --- a/include/bootstrap.h +++ b/include/bootstrap.h @@ -106,34 +106,38 @@ void cuda_circuit_bootstrap_64( uint32_t number_of_samples, uint32_t max_shared_memory); void scratch_cuda_circuit_bootstrap_vertical_packing_32( - void *v_stream, uint32_t gpu_index, void **cbs_vp_buffer, + void *v_stream, uint32_t gpu_index, int8_t **cbs_vp_buffer, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, - uint32_t number_of_inputs, uint32_t tau, bool allocate_gpu_memory); + uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory, + bool allocate_gpu_memory); void scratch_cuda_circuit_bootstrap_vertical_packing_64( - void *v_stream, uint32_t gpu_index, void **cbs_vp_buffer, + void *v_stream, uint32_t gpu_index, int8_t **cbs_vp_buffer, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, - uint32_t number_of_inputs, uint32_t tau, bool allocate_gpu_memory); + uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory, + bool allocate_gpu_memory); void scratch_cuda_wop_pbs_32( - void *v_stream, uint32_t gpu_index, void **wop_pbs_buffer, + void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer, uint32_t *delta_log, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, uint32_t number_of_bits_of_message_including_padding, - uint32_t number_of_bits_to_extract, uint32_t number_of_inputs); + uint32_t number_of_bits_to_extract, uint32_t number_of_inputs, + uint32_t max_shared_memory); void scratch_cuda_wop_pbs_64( - void *v_stream, uint32_t gpu_index, void **wop_pbs_buffer, + void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer, uint32_t *delta_log, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, uint32_t number_of_bits_of_message_including_padding, - uint32_t number_of_bits_to_extract, uint32_t number_of_inputs); + uint32_t number_of_bits_to_extract, uint32_t number_of_inputs, + uint32_t max_shared_memory); void cuda_circuit_bootstrap_vertical_packing_64( void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in, - void *fourier_bsk, void *cbs_fpksk, void *lut_vector, void *cbs_vp_buffer, + void *fourier_bsk, void *cbs_fpksk, void *lut_vector, int8_t *cbs_vp_buffer, uint32_t cbs_delta_log, uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_count_bsk, uint32_t base_log_bsk, uint32_t level_count_pksk, uint32_t base_log_pksk, uint32_t level_count_cbs, @@ -142,7 +146,7 @@ void cuda_circuit_bootstrap_vertical_packing_64( void cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in, void *lut_vector, void *fourier_bsk, - void *ksk, void *cbs_fpksk, void *wop_pbs_buffer, + void *ksk, void *cbs_fpksk, int8_t *wop_pbs_buffer, uint32_t cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk, @@ -153,18 +157,12 @@ void cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, void *lwe_array_out, uint32_t number_of_bits_to_extract, uint32_t delta_log, uint32_t number_of_inputs, uint32_t max_shared_memory); -void cleanup_cuda_wop_pbs_32(void *v_stream, uint32_t gpu_index, - void **wop_pbs_buffer); -void cleanup_cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, - void **wop_pbs_buffer); +void cleanup_cuda_wop_pbs(void *v_stream, uint32_t gpu_index, + int8_t **wop_pbs_buffer); -void cleanup_cuda_circuit_bootstrap_vertical_packing_32(void *v_stream, - uint32_t gpu_index, - void **cbs_vp_buffer); - -void cleanup_cuda_circuit_bootstrap_vertical_packing_64(void *v_stream, - uint32_t gpu_index, - void **cbs_vp_buffer); +void cleanup_cuda_circuit_bootstrap_vertical_packing(void *v_stream, + uint32_t gpu_index, + int8_t **cbs_vp_buffer); } #ifdef __CUDACC__ diff --git a/src/bootstrap_amortized.cuh b/src/bootstrap_amortized.cuh index 6162acb6b..81ace4848 100644 --- a/src/bootstrap_amortized.cuh +++ b/src/bootstrap_amortized.cuh @@ -52,15 +52,15 @@ template */ __global__ void device_bootstrap_amortized( Torus *lwe_array_out, Torus *lut_vector, Torus *lut_vector_indexes, - Torus *lwe_array_in, double2 *bootstrapping_key, char *device_mem, + Torus *lwe_array_in, double2 *bootstrapping_key, int8_t *device_mem, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t lwe_idx, size_t device_memory_size_per_sample) { // We use shared memory for the polynomials that are used often during the // bootstrap, since shared memory is kept in L1 cache and accessing it is // much faster than global memory - extern __shared__ char sharedmem[]; - char *selected_memory; + extern __shared__ int8_t sharedmem[]; + int8_t *selected_memory; if constexpr (SMD == FULLSM) selected_memory = sharedmem; @@ -241,7 +241,7 @@ __host__ void host_bootstrap_amortized( auto stream = static_cast(v_stream); - char *d_mem; + int8_t *d_mem; // Create a 1-dimensional grid of threads // where each block handles 1 sample and each thread @@ -257,8 +257,8 @@ __host__ void host_bootstrap_amortized( // from one of three templates (no use, partial use or full use // of shared memory) if (max_shared_memory < SM_PART) { - d_mem = (char *)cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count, - stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count, + stream, gpu_index); device_bootstrap_amortized<<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, bootstrapping_key, d_mem, glwe_dimension, lwe_dimension, @@ -268,8 +268,8 @@ __host__ void host_bootstrap_amortized( cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART); cudaFuncSetCacheConfig(device_bootstrap_amortized, cudaFuncCachePreferShared); - d_mem = (char *)cuda_malloc_async(DM_PART * input_lwe_ciphertext_count, - stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(DM_PART * input_lwe_ciphertext_count, + stream, gpu_index); device_bootstrap_amortized <<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, @@ -287,7 +287,7 @@ __host__ void host_bootstrap_amortized( check_cuda_error(cudaFuncSetCacheConfig( device_bootstrap_amortized, cudaFuncCachePreferShared)); - d_mem = (char *)cuda_malloc_async(0, stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(0, stream, gpu_index); device_bootstrap_amortized <<>>( diff --git a/src/bootstrap_low_latency.cuh b/src/bootstrap_low_latency.cuh index dc9dfd7e5..587579741 100644 --- a/src/bootstrap_low_latency.cuh +++ b/src/bootstrap_low_latency.cuh @@ -134,15 +134,15 @@ __global__ void device_bootstrap_low_latency( Torus *lwe_array_out, Torus *lut_vector, Torus *lwe_array_in, double2 *bootstrapping_key, double2 *join_buffer, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - char *device_mem, int device_memory_size_per_block) { + int8_t *device_mem, int device_memory_size_per_block) { grid_group grid = this_grid(); // We use shared memory for the polynomials that are used often during the // bootstrap, since shared memory is kept in L1 cache and accessing it is // much faster than global memory - extern __shared__ char sharedmem[]; - char *selected_memory; + extern __shared__ int8_t sharedmem[]; + int8_t *selected_memory; int block_index = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; uint32_t glwe_dimension = gridDim.y - 1; @@ -275,7 +275,7 @@ __host__ void host_bootstrap_low_latency( int DM_PART = DM_FULL - SM_PART; - char *d_mem; + int8_t *d_mem; int thds = polynomial_size / params::opt; dim3 grid(level_count, glwe_dimension + 1, input_lwe_ciphertext_count); @@ -295,18 +295,18 @@ __host__ void host_bootstrap_low_latency( if (max_shared_memory < SM_PART) { kernel_args[10] = &DM_FULL; check_cuda_error(cudaGetLastError()); - d_mem = (char *)cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count * - level_count * (glwe_dimension + 1), - stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count * + level_count * (glwe_dimension + 1), + stream, gpu_index); check_cuda_error(cudaGetLastError()); check_cuda_error(cudaLaunchCooperativeKernel( (void *)device_bootstrap_low_latency, grid, thds, (void **)kernel_args, 0, *stream)); } else if (max_shared_memory < SM_FULL) { kernel_args[10] = &DM_PART; - d_mem = (char *)cuda_malloc_async(DM_PART * input_lwe_ciphertext_count * - level_count * (glwe_dimension + 1), - stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(DM_PART * input_lwe_ciphertext_count * + level_count * (glwe_dimension + 1), + stream, gpu_index); check_cuda_error(cudaFuncSetAttribute( device_bootstrap_low_latency, cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART)); @@ -321,7 +321,7 @@ __host__ void host_bootstrap_low_latency( } else { int DM_NONE = 0; kernel_args[10] = &DM_NONE; - d_mem = (char *)cuda_malloc_async(0, stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(0, stream, gpu_index); check_cuda_error(cudaFuncSetAttribute( device_bootstrap_low_latency, cudaFuncAttributeMaxDynamicSharedMemorySize, SM_FULL)); diff --git a/src/crypto/ggsw.cuh b/src/crypto/ggsw.cuh index 5db8bc399..aeb33c8cd 100644 --- a/src/crypto/ggsw.cuh +++ b/src/crypto/ggsw.cuh @@ -6,9 +6,9 @@ template __global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src, - char *device_mem) { + int8_t *device_mem) { - extern __shared__ char sharedmem[]; + extern __shared__ int8_t sharedmem[]; double2 *selected_memory; if constexpr (SMD == FULLSM) @@ -59,9 +59,9 @@ void batch_fft_ggsw_vector(cudaStream_t *stream, double2 *dest, T *src, int gridSize = r * (glwe_dim + 1) * (glwe_dim + 1) * level_count; int blockSize = polynomial_size / params::opt; - char *d_mem; + int8_t *d_mem; if (max_shared_memory < shared_memory_size) { - d_mem = (char *)cuda_malloc_async(shared_memory_size, stream, gpu_index); + d_mem = (int8_t *)cuda_malloc_async(shared_memory_size, stream, gpu_index); device_batch_fft_ggsw_vector <<>>(dest, src, d_mem); check_cuda_error(cudaGetLastError()); diff --git a/src/keyswitch.cuh b/src/keyswitch.cuh index 7dd339c51..42021c9d1 100644 --- a/src/keyswitch.cuh +++ b/src/keyswitch.cuh @@ -44,7 +44,7 @@ fp_keyswitch(Torus *glwe_array_out, Torus *lwe_array_in, Torus *fp_ksk_array, size_t chunk_id = blockIdx.x; size_t ksk_id = ciphertext_id % number_of_keys; - extern __shared__ char sharedmem[]; + extern __shared__ int8_t sharedmem[]; // result accumulator, shared memory is used because of frequent access Torus *local_glwe_chunk = (Torus *)sharedmem; @@ -106,7 +106,7 @@ __global__ void keyswitch(Torus *lwe_array_out, Torus *lwe_array_in, Torus *ksk, int lwe_lower, int lwe_upper, int cutoff) { int tid = threadIdx.x; - extern __shared__ char sharedmem[]; + extern __shared__ int8_t sharedmem[]; Torus *local_lwe_array_out = (Torus *)sharedmem; diff --git a/src/polynomial/polynomial.cuh b/src/polynomial/polynomial.cuh index 3c2c7d1c6..d8380a901 100644 --- a/src/polynomial/polynomial.cuh +++ b/src/polynomial/polynomial.cuh @@ -39,7 +39,7 @@ public: __device__ Polynomial(T *coefficients, uint32_t degree) : coefficients(coefficients), degree(degree) {} - __device__ Polynomial(char *memory, uint32_t degree) + __device__ Polynomial(int8_t *memory, uint32_t degree) : coefficients((T *)memory), degree(degree) {} __host__ void copy_to_host(T *dest) { @@ -49,7 +49,7 @@ public: __device__ T get_coefficient(int i) { return this->coefficients[i]; } - __device__ char *reuse_memory() { return (char *)coefficients; } + __device__ int8_t *reuse_memory() { return (int8_t *)coefficients; } __device__ void copy_coefficients_from(Polynomial &source, int begin_dest = 0, diff --git a/src/vertical_packing.cuh b/src/vertical_packing.cuh index bb9b5a22b..baa5d77cc 100644 --- a/src/vertical_packing.cuh +++ b/src/vertical_packing.cuh @@ -54,7 +54,7 @@ template __device__ void ifft_inplace(double2 *data) { template __device__ void cmux(Torus *glwe_array_out, Torus *glwe_array_in, double2 *ggsw_in, - char *selected_memory, uint32_t output_idx, uint32_t input_idx1, + int8_t *selected_memory, uint32_t output_idx, uint32_t input_idx1, uint32_t input_idx2, uint32_t glwe_dim, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t ggsw_idx) { @@ -211,7 +211,7 @@ __host__ void add_padding_to_lut_async(Torus *lut_out, Torus *lut_in, */ template __global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in, - double2 *ggsw_in, char *device_mem, + double2 *ggsw_in, int8_t *device_mem, size_t device_memory_size_per_block, uint32_t glwe_dim, uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, @@ -228,8 +228,8 @@ __global__ void device_batch_cmux(Torus *glwe_array_out, Torus *glwe_array_in, int input_idx2 = (cmux_idx << 1) + 1; // We use shared memory for intermediate result - extern __shared__ char sharedmem[]; - char *selected_memory; + extern __shared__ int8_t sharedmem[]; + int8_t *selected_memory; if constexpr (SMD == FULLSM) selected_memory = sharedmem; @@ -299,9 +299,9 @@ __host__ void host_cmux_tree(void *v_stream, uint32_t gpu_index, ////////////////////// // Allocate global memory in case parameters are too large - char *d_mem; + int8_t *d_mem; if (max_shared_memory < memory_needed_per_block) { - d_mem = (char *)cuda_malloc_async( + d_mem = (int8_t *)cuda_malloc_async( memory_needed_per_block * (1 << (r - 1)) * tau, stream, gpu_index); } else { check_cuda_error(cudaFuncSetAttribute( @@ -380,7 +380,7 @@ __host__ void host_cmux_tree(void *v_stream, uint32_t gpu_index, * - glwe_dim: This is k. * - polynomial_size: size of the polynomials. This is N. * - base_log: log base used for the gadget matrix - B = 2^base_log (~8) - * - l_gadget: number of decomposition levels in the gadget matrix (~4) + * - level_count: number of decomposition levels in the gadget matrix (~4) * - device_memory_size_per_sample: Amount of (shared/global) memory used for * the accumulators. * - device_mem: An array to be used for the accumulators. Can be in the shared @@ -390,12 +390,12 @@ template __global__ void device_blind_rotation_and_sample_extraction( Torus *lwe_out, Torus *glwe_in, double2 *ggsw_in, // m^BR uint32_t mbr_size, uint32_t glwe_dim, uint32_t polynomial_size, - uint32_t base_log, uint32_t l_gadget, size_t device_memory_size_per_sample, - char *device_mem) { + uint32_t base_log, uint32_t level_count, + size_t device_memory_size_per_sample, int8_t *device_mem) { // We use shared memory for intermediate result - extern __shared__ char sharedmem[]; - char *selected_memory; + extern __shared__ int8_t sharedmem[]; + int8_t *selected_memory; if constexpr (SMD == FULLSM) selected_memory = sharedmem; @@ -433,10 +433,10 @@ __global__ void device_blind_rotation_and_sample_extraction( // ACC = CMUX ( Ci, x^ai * ACC, ACC ) synchronize_threads_in_block(); - cmux(accumulator_c0, accumulator_c0, ggsw_in, - (char *)(accumulator_c0 + 4 * polynomial_size), - 0, 0, 1, glwe_dim, polynomial_size, base_log, - l_gadget, i); + cmux( + accumulator_c0, accumulator_c0, ggsw_in, + (int8_t *)(accumulator_c0 + 4 * polynomial_size), 0, 0, 1, glwe_dim, + polynomial_size, base_log, level_count, i); } synchronize_threads_in_block(); @@ -455,7 +455,7 @@ template __host__ void host_blind_rotate_and_sample_extraction( void *v_stream, uint32_t gpu_index, Torus *lwe_out, Torus *ggsw_in, Torus *lut_vector, uint32_t mbr_size, uint32_t tau, uint32_t glwe_dimension, - uint32_t polynomial_size, uint32_t base_log, uint32_t l_gadget, + uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, uint32_t max_shared_memory) { cudaSetDevice(gpu_index); @@ -474,10 +474,10 @@ __host__ void host_blind_rotate_and_sample_extraction( sizeof(double2) * polynomial_size / 2 + // body_res_fft sizeof(double2) * polynomial_size / 2; // glwe_fft - char *d_mem; + int8_t *d_mem; if (max_shared_memory < memory_needed_per_block) - d_mem = (char *)cuda_malloc_async(memory_needed_per_block * tau, stream, - gpu_index); + d_mem = (int8_t *)cuda_malloc_async(memory_needed_per_block * tau, stream, + gpu_index); else { check_cuda_error(cudaFuncSetAttribute( device_blind_rotation_and_sample_extraction( stream, d_ggsw_fft_in, ggsw_in, mbr_size, glwe_dimension, polynomial_size, - l_gadget, gpu_index, max_shared_memory); + level_count, gpu_index, max_shared_memory); check_cuda_error(cudaGetLastError()); // @@ -509,14 +509,14 @@ __host__ void host_blind_rotate_and_sample_extraction( <<>>(lwe_out, lut_vector, d_ggsw_fft_in, mbr_size, glwe_dimension, // k - polynomial_size, base_log, l_gadget, + polynomial_size, base_log, level_count, memory_needed_per_block, d_mem); else device_blind_rotation_and_sample_extraction <<>>( lwe_out, lut_vector, d_ggsw_fft_in, mbr_size, glwe_dimension, // k - polynomial_size, base_log, l_gadget, memory_needed_per_block, + polynomial_size, base_log, level_count, memory_needed_per_block, d_mem); check_cuda_error(cudaGetLastError()); diff --git a/src/wop_bootstrap.cu b/src/wop_bootstrap.cu index f9afc6fed..12321ae9a 100644 --- a/src/wop_bootstrap.cu +++ b/src/wop_bootstrap.cu @@ -7,15 +7,46 @@ * circuit bootstrap. */ void scratch_cuda_circuit_bootstrap_vertical_packing_32( - void *v_stream, uint32_t gpu_index, void **cbs_vp_buffer, + void *v_stream, uint32_t gpu_index, int8_t **cbs_vp_buffer, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, - uint32_t number_of_inputs, uint32_t tau, bool allocate_gpu_memory) { + uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory, + bool allocate_gpu_memory) { - scratch_circuit_bootstrap_vertical_packing( - v_stream, gpu_index, (uint32_t **)cbs_vp_buffer, cbs_delta_log, - glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, - number_of_inputs, tau, allocate_gpu_memory); + switch (polynomial_size) { + case 512: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 1024: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 2048: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 4096: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 8192: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + default: + break; + } } /* @@ -25,15 +56,46 @@ void scratch_cuda_circuit_bootstrap_vertical_packing_32( * circuit bootstrap. */ void scratch_cuda_circuit_bootstrap_vertical_packing_64( - void *v_stream, uint32_t gpu_index, void **cbs_vp_buffer, + void *v_stream, uint32_t gpu_index, int8_t **cbs_vp_buffer, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, - uint32_t number_of_inputs, uint32_t tau, bool allocate_gpu_memory) { + uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory, + bool allocate_gpu_memory) { - scratch_circuit_bootstrap_vertical_packing( - v_stream, gpu_index, (uint64_t **)cbs_vp_buffer, cbs_delta_log, - glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, - number_of_inputs, tau, allocate_gpu_memory); + switch (polynomial_size) { + case 512: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 1024: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 2048: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 4096: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + case 8192: + scratch_circuit_bootstrap_vertical_packing>( + v_stream, gpu_index, cbs_vp_buffer, cbs_delta_log, glwe_dimension, + lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau, + max_shared_memory, allocate_gpu_memory); + break; + default: + break; + } } /* @@ -43,16 +105,51 @@ void scratch_cuda_circuit_bootstrap_vertical_packing_64( * bootstrap. */ void scratch_cuda_wop_pbs_32( - void *v_stream, uint32_t gpu_index, void **wop_pbs_buffer, + void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer, uint32_t *delta_log, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, uint32_t number_of_bits_of_message_including_padding, - uint32_t number_of_bits_to_extract, uint32_t number_of_inputs) { - scratch_wop_pbs(v_stream, gpu_index, (uint32_t **)wop_pbs_buffer, - delta_log, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, level_count_cbs, - number_of_bits_of_message_including_padding, - number_of_bits_to_extract, number_of_inputs); + uint32_t number_of_bits_to_extract, uint32_t number_of_inputs, + uint32_t max_shared_memory) { + switch (polynomial_size) { + case 512: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 1024: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 2048: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 4096: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 8192: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + default: + break; + } } /* @@ -62,16 +159,51 @@ void scratch_cuda_wop_pbs_32( * bootstrap. */ void scratch_cuda_wop_pbs_64( - void *v_stream, uint32_t gpu_index, void **wop_pbs_buffer, + void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer, uint32_t *delta_log, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, uint32_t number_of_bits_of_message_including_padding, - uint32_t number_of_bits_to_extract, uint32_t number_of_inputs) { - scratch_wop_pbs(v_stream, gpu_index, (uint64_t **)wop_pbs_buffer, - delta_log, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, level_count_cbs, - number_of_bits_of_message_including_padding, - number_of_bits_to_extract, number_of_inputs); + uint32_t number_of_bits_to_extract, uint32_t number_of_inputs, + uint32_t max_shared_memory) { + switch (polynomial_size) { + case 512: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 1024: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 2048: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 4096: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + case 8192: + scratch_wop_pbs>( + v_stream, gpu_index, wop_pbs_buffer, delta_log, cbs_delta_log, + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, number_of_bits_to_extract, + number_of_inputs, max_shared_memory); + break; + default: + break; + } } /* @@ -104,7 +236,7 @@ void scratch_cuda_wop_pbs_64( */ void cuda_circuit_bootstrap_vertical_packing_64( void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in, - void *fourier_bsk, void *cbs_fpksk, void *lut_vector, void *cbs_vp_buffer, + void *fourier_bsk, void *cbs_fpksk, void *lut_vector, int8_t *cbs_vp_buffer, uint32_t cbs_delta_log, uint32_t polynomial_size, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t level_count_bsk, uint32_t base_log_bsk, uint32_t level_count_pksk, uint32_t base_log_pksk, uint32_t level_count_cbs, @@ -131,51 +263,51 @@ void cuda_circuit_bootstrap_vertical_packing_64( host_circuit_bootstrap_vertical_packing>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, - (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, - (uint64_t *)cbs_vp_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, - polynomial_size, base_log_bsk, level_count_bsk, base_log_pksk, - level_count_pksk, base_log_cbs, level_count_cbs, number_of_inputs, - lut_number, max_shared_memory); + (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, cbs_vp_buffer, + cbs_delta_log, glwe_dimension, lwe_dimension, polynomial_size, + base_log_bsk, level_count_bsk, base_log_pksk, level_count_pksk, + base_log_cbs, level_count_cbs, number_of_inputs, lut_number, + max_shared_memory); break; case 1024: host_circuit_bootstrap_vertical_packing>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, - (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, - (uint64_t *)cbs_vp_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, - polynomial_size, base_log_bsk, level_count_bsk, base_log_pksk, - level_count_pksk, base_log_cbs, level_count_cbs, number_of_inputs, - lut_number, max_shared_memory); + (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, cbs_vp_buffer, + cbs_delta_log, glwe_dimension, lwe_dimension, polynomial_size, + base_log_bsk, level_count_bsk, base_log_pksk, level_count_pksk, + base_log_cbs, level_count_cbs, number_of_inputs, lut_number, + max_shared_memory); break; case 2048: host_circuit_bootstrap_vertical_packing>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, - (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, - (uint64_t *)cbs_vp_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, - polynomial_size, base_log_bsk, level_count_bsk, base_log_pksk, - level_count_pksk, base_log_cbs, level_count_cbs, number_of_inputs, - lut_number, max_shared_memory); + (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, cbs_vp_buffer, + cbs_delta_log, glwe_dimension, lwe_dimension, polynomial_size, + base_log_bsk, level_count_bsk, base_log_pksk, level_count_pksk, + base_log_cbs, level_count_cbs, number_of_inputs, lut_number, + max_shared_memory); break; case 4096: host_circuit_bootstrap_vertical_packing>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, - (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, - (uint64_t *)cbs_vp_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, - polynomial_size, base_log_bsk, level_count_bsk, base_log_pksk, - level_count_pksk, base_log_cbs, level_count_cbs, number_of_inputs, - lut_number, max_shared_memory); + (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, cbs_vp_buffer, + cbs_delta_log, glwe_dimension, lwe_dimension, polynomial_size, + base_log_bsk, level_count_bsk, base_log_pksk, level_count_pksk, + base_log_cbs, level_count_cbs, number_of_inputs, lut_number, + max_shared_memory); break; case 8192: host_circuit_bootstrap_vertical_packing>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, - (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, - (uint64_t *)cbs_vp_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, - polynomial_size, base_log_bsk, level_count_bsk, base_log_pksk, - level_count_pksk, base_log_cbs, level_count_cbs, number_of_inputs, - lut_number, max_shared_memory); + (double2 *)fourier_bsk, (uint64_t *)cbs_fpksk, cbs_vp_buffer, + cbs_delta_log, glwe_dimension, lwe_dimension, polynomial_size, + base_log_bsk, level_count_bsk, base_log_pksk, level_count_pksk, + base_log_cbs, level_count_cbs, number_of_inputs, lut_number, + max_shared_memory); break; default: break; @@ -219,7 +351,7 @@ void cuda_circuit_bootstrap_vertical_packing_64( */ void cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, void *lwe_array_out, void *lwe_array_in, void *lut_vector, void *fourier_bsk, - void *ksk, void *cbs_fpksk, void *wop_pbs_buffer, + void *ksk, void *cbs_fpksk, int8_t *wop_pbs_buffer, uint32_t cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk, @@ -251,60 +383,60 @@ void cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, void *lwe_array_out, v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, (double2 *)fourier_bsk, (uint64_t *)ksk, (uint64_t *)cbs_fpksk, - (uint64_t *)wop_pbs_buffer, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, base_log_bsk, level_count_bsk, - base_log_ksk, level_count_ksk, base_log_pksk, level_count_pksk, - base_log_cbs, level_count_cbs, - number_of_bits_of_message_including_padding, number_of_bits_to_extract, - delta_log, number_of_inputs, max_shared_memory); + wop_pbs_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, + polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk, + level_count_ksk, base_log_pksk, level_count_pksk, base_log_cbs, + level_count_cbs, number_of_bits_of_message_including_padding, + number_of_bits_to_extract, delta_log, number_of_inputs, + max_shared_memory); break; case 1024: host_wop_pbs>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, (double2 *)fourier_bsk, (uint64_t *)ksk, (uint64_t *)cbs_fpksk, - (uint64_t *)wop_pbs_buffer, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, base_log_bsk, level_count_bsk, - base_log_ksk, level_count_ksk, base_log_pksk, level_count_pksk, - base_log_cbs, level_count_cbs, - number_of_bits_of_message_including_padding, number_of_bits_to_extract, - delta_log, number_of_inputs, max_shared_memory); + wop_pbs_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, + polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk, + level_count_ksk, base_log_pksk, level_count_pksk, base_log_cbs, + level_count_cbs, number_of_bits_of_message_including_padding, + number_of_bits_to_extract, delta_log, number_of_inputs, + max_shared_memory); break; case 2048: host_wop_pbs>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, (double2 *)fourier_bsk, (uint64_t *)ksk, (uint64_t *)cbs_fpksk, - (uint64_t *)wop_pbs_buffer, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, base_log_bsk, level_count_bsk, - base_log_ksk, level_count_ksk, base_log_pksk, level_count_pksk, - base_log_cbs, level_count_cbs, - number_of_bits_of_message_including_padding, number_of_bits_to_extract, - delta_log, number_of_inputs, max_shared_memory); + wop_pbs_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, + polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk, + level_count_ksk, base_log_pksk, level_count_pksk, base_log_cbs, + level_count_cbs, number_of_bits_of_message_including_padding, + number_of_bits_to_extract, delta_log, number_of_inputs, + max_shared_memory); break; case 4096: host_wop_pbs>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, (double2 *)fourier_bsk, (uint64_t *)ksk, (uint64_t *)cbs_fpksk, - (uint64_t *)wop_pbs_buffer, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, base_log_bsk, level_count_bsk, - base_log_ksk, level_count_ksk, base_log_pksk, level_count_pksk, - base_log_cbs, level_count_cbs, - number_of_bits_of_message_including_padding, number_of_bits_to_extract, - delta_log, number_of_inputs, max_shared_memory); + wop_pbs_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, + polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk, + level_count_ksk, base_log_pksk, level_count_pksk, base_log_cbs, + level_count_cbs, number_of_bits_of_message_including_padding, + number_of_bits_to_extract, delta_log, number_of_inputs, + max_shared_memory); break; case 8192: host_wop_pbs>( v_stream, gpu_index, (uint64_t *)lwe_array_out, (uint64_t *)lwe_array_in, (uint64_t *)lut_vector, (double2 *)fourier_bsk, (uint64_t *)ksk, (uint64_t *)cbs_fpksk, - (uint64_t *)wop_pbs_buffer, cbs_delta_log, glwe_dimension, - lwe_dimension, polynomial_size, base_log_bsk, level_count_bsk, - base_log_ksk, level_count_ksk, base_log_pksk, level_count_pksk, - base_log_cbs, level_count_cbs, - number_of_bits_of_message_including_padding, number_of_bits_to_extract, - delta_log, number_of_inputs, max_shared_memory); + wop_pbs_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, + polynomial_size, base_log_bsk, level_count_bsk, base_log_ksk, + level_count_ksk, base_log_pksk, level_count_pksk, base_log_cbs, + level_count_cbs, number_of_bits_of_message_including_padding, + number_of_bits_to_extract, delta_log, number_of_inputs, + max_shared_memory); break; default: break; @@ -313,38 +445,20 @@ void cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, void *lwe_array_out, /* * This cleanup function frees the data for the wop PBS on GPU in wop_pbs_buffer - * for 32 bits inputs. + * for 32 or 64 bits inputs. */ -void cleanup_cuda_wop_pbs_32(void *v_stream, uint32_t gpu_index, - void **wop_pbs_buffer) { - cleanup_wop_pbs(v_stream, gpu_index, (uint32_t **)wop_pbs_buffer); -} -/* - * This cleanup function frees the data for the wop PBS on GPU in wop_pbs_buffer - * for 64 bits inputs. - */ -void cleanup_cuda_wop_pbs_64(void *v_stream, uint32_t gpu_index, - void **wop_pbs_buffer) { - cleanup_wop_pbs(v_stream, gpu_index, (uint64_t **)wop_pbs_buffer); +void cleanup_cuda_wop_pbs(void *v_stream, uint32_t gpu_index, + int8_t **wop_pbs_buffer) { + cleanup_wop_pbs(v_stream, gpu_index, wop_pbs_buffer); } /* * This cleanup function frees the data for the circuit bootstrap and vertical - * packing on GPU in cbs_vp_buffer for 32 bits inputs. + * packing on GPU in cbs_vp_buffer for 32 or 64 bits inputs. */ -void cleanup_cuda_circuit_bootstrap_vertical_packing_32(void *v_stream, - uint32_t gpu_index, - void **cbs_vp_buffer) { - cleanup_circuit_bootstrap_vertical_packing( - v_stream, gpu_index, (uint32_t **)cbs_vp_buffer); -} -/* - * This cleanup function frees the data for the circuit bootstrap and vertical - * packing on GPU in cbs_vp_buffer for 64 bits inputs. - */ -void cleanup_cuda_circuit_bootstrap_vertical_packing_64(void *v_stream, - uint32_t gpu_index, - void **cbs_vp_buffer) { - cleanup_circuit_bootstrap_vertical_packing( - v_stream, gpu_index, (uint64_t **)cbs_vp_buffer); +void cleanup_cuda_circuit_bootstrap_vertical_packing(void *v_stream, + uint32_t gpu_index, + int8_t **cbs_vp_buffer) { + cleanup_circuit_bootstrap_vertical_packing(v_stream, gpu_index, + cbs_vp_buffer); } diff --git a/src/wop_bootstrap.cuh b/src/wop_bootstrap.cuh index ba1e6ddad..23464ace5 100644 --- a/src/wop_bootstrap.cuh +++ b/src/wop_bootstrap.cuh @@ -50,12 +50,13 @@ get_buffer_size_cbs_vp(uint32_t glwe_dimension, uint32_t lwe_dimension, sizeof(Torus); // glwe_array_out } -template +template __host__ void scratch_circuit_bootstrap_vertical_packing( - void *v_stream, uint32_t gpu_index, Torus **cbs_vp_buffer, + void *v_stream, uint32_t gpu_index, int8_t **cbs_vp_buffer, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, - uint32_t number_of_inputs, uint32_t tau, bool allocate_gpu_memory) { + uint32_t number_of_inputs, uint32_t tau, uint32_t max_shared_memory, + bool allocate_gpu_memory) { cudaSetDevice(gpu_index); auto stream = static_cast(v_stream); @@ -63,20 +64,22 @@ __host__ void scratch_circuit_bootstrap_vertical_packing( // Allocate lut vector indexes on the CPU first to avoid blocking the stream Torus *h_lut_vector_indexes = (Torus *)malloc(number_of_inputs * level_count_cbs * sizeof(Torus)); + uint32_t r = number_of_inputs - params::log2_degree; // allocate and initialize device pointers for circuit bootstrap and vertical // packing if (allocate_gpu_memory) { int buffer_size = get_buffer_size_cbs_vp( glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, number_of_inputs, tau); - *cbs_vp_buffer = (Torus *)cuda_malloc_async(buffer_size, stream, gpu_index); + *cbs_vp_buffer = + (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index); } // indexes of lut vectors for cbs for (uint index = 0; index < level_count_cbs * number_of_inputs; index++) { h_lut_vector_indexes[index] = index % level_count_cbs; } // lut_vector_indexes is the first buffer in the cbs_vp_buffer - cuda_memcpy_async_to_gpu(*cbs_vp_buffer, h_lut_vector_indexes, + cuda_memcpy_async_to_gpu((Torus *)*cbs_vp_buffer, h_lut_vector_indexes, number_of_inputs * level_count_cbs * sizeof(Torus), stream, gpu_index); check_cuda_error(cudaStreamSynchronize(*stream)); @@ -98,10 +101,9 @@ __host__ void scratch_circuit_bootstrap_vertical_packing( * - lut_vector_cbs * - lut_vector_indexes */ -template __host__ void cleanup_circuit_bootstrap_vertical_packing(void *v_stream, uint32_t gpu_index, - Torus **cbs_vp_buffer) { + int8_t **cbs_vp_buffer) { auto stream = static_cast(v_stream); // Free memory @@ -115,7 +117,7 @@ template __host__ void host_circuit_bootstrap_vertical_packing( void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lwe_array_in, Torus *lut_vector, double2 *fourier_bsk, - Torus *cbs_fpksk, Torus *cbs_vp_buffer, uint32_t cbs_delta_log, + Torus *cbs_fpksk, int8_t *cbs_vp_buffer, uint32_t cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_pksk, uint32_t level_count_pksk, uint32_t base_log_cbs, uint32_t level_count_cbs, @@ -192,14 +194,15 @@ get_buffer_size_wop_pbs(uint32_t glwe_dimension, uint32_t lwe_dimension, (number_of_bits_of_message_including_padding) * sizeof(Torus); } -template +template __host__ void -scratch_wop_pbs(void *v_stream, uint32_t gpu_index, Torus **wop_pbs_buffer, +scratch_wop_pbs(void *v_stream, uint32_t gpu_index, int8_t **wop_pbs_buffer, uint32_t *delta_log, uint32_t *cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t level_count_cbs, uint32_t number_of_bits_of_message_including_padding, - uint32_t number_of_bits_to_extract, uint32_t number_of_inputs) { + uint32_t number_of_bits_to_extract, uint32_t number_of_inputs, + uint32_t max_shared_memory) { cudaSetDevice(gpu_index); auto stream = static_cast(v_stream); @@ -208,33 +211,33 @@ scratch_wop_pbs(void *v_stream, uint32_t gpu_index, Torus **wop_pbs_buffer, glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, number_of_bits_of_message_including_padding, number_of_bits_to_extract, number_of_inputs); - int buffer_size = - get_buffer_size_cbs_vp( - glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, - number_of_inputs * number_of_bits_to_extract, number_of_inputs) + - wop_pbs_buffer_size; + uint32_t cbs_vp_number_of_inputs = + number_of_inputs * number_of_bits_to_extract; + uint32_t tau = number_of_inputs; + uint32_t r = cbs_vp_number_of_inputs - params::log2_degree; + int buffer_size = get_buffer_size_cbs_vp( + glwe_dimension, lwe_dimension, polynomial_size, + level_count_cbs, cbs_vp_number_of_inputs, tau) + + wop_pbs_buffer_size; - *wop_pbs_buffer = (Torus *)cuda_malloc_async(buffer_size, stream, gpu_index); + *wop_pbs_buffer = (int8_t *)cuda_malloc_async(buffer_size, stream, gpu_index); // indexes of lut vectors for bit extract Torus h_lut_vector_indexes = 0; // lut_vector_indexes is the first array in the wop_pbs buffer - cuda_memcpy_async_to_gpu(*wop_pbs_buffer, &h_lut_vector_indexes, + cuda_memcpy_async_to_gpu(*wop_pbs_buffer, (int8_t *)&h_lut_vector_indexes, sizeof(Torus), stream, gpu_index); check_cuda_error(cudaGetLastError()); uint32_t ciphertext_total_bits_count = sizeof(Torus) * 8; *delta_log = ciphertext_total_bits_count - number_of_bits_of_message_including_padding; - Torus *cbs_vp_buffer = - *wop_pbs_buffer + - (ptrdiff_t)( - 1 + ((glwe_dimension + 1) * polynomial_size) + (polynomial_size + 1) + - (polynomial_size + 1) + (lwe_dimension + 1) + (polynomial_size + 1) + - (lwe_dimension + 1) * (number_of_bits_of_message_including_padding)); - scratch_circuit_bootstrap_vertical_packing( + int8_t *cbs_vp_buffer = + (int8_t *)*wop_pbs_buffer + (ptrdiff_t)wop_pbs_buffer_size; + scratch_circuit_bootstrap_vertical_packing( v_stream, gpu_index, &cbs_vp_buffer, cbs_delta_log, glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, - number_of_inputs * number_of_bits_to_extract, number_of_inputs, false); + number_of_inputs * number_of_bits_to_extract, number_of_inputs, + max_shared_memory, false); } /* @@ -242,9 +245,8 @@ scratch_wop_pbs(void *v_stream, uint32_t gpu_index, Torus **wop_pbs_buffer, * Data that lives on the CPU is prefixed with `h_`. This cleanup function thus * frees the data for the wop PBS on GPU in wop_pbs_buffer */ -template __host__ void cleanup_wop_pbs(void *v_stream, uint32_t gpu_index, - Torus **wop_pbs_buffer) { + int8_t **wop_pbs_buffer) { auto stream = static_cast(v_stream); cuda_drop_async(*wop_pbs_buffer, stream, gpu_index); } @@ -253,7 +255,7 @@ template __host__ void host_wop_pbs( void *v_stream, uint32_t gpu_index, Torus *lwe_array_out, Torus *lwe_array_in, Torus *lut_vector, double2 *fourier_bsk, Torus *ksk, - Torus *cbs_fpksk, Torus *wop_pbs_buffer, uint32_t cbs_delta_log, + Torus *cbs_fpksk, int8_t *wop_pbs_buffer, uint32_t cbs_delta_log, uint32_t glwe_dimension, uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log_bsk, uint32_t level_count_bsk, uint32_t base_log_ksk, uint32_t level_count_ksk, uint32_t base_log_pksk, uint32_t level_count_pksk, @@ -284,12 +286,12 @@ __host__ void host_wop_pbs( level_count_ksk, number_of_inputs, max_shared_memory); check_cuda_error(cudaGetLastError()); - Torus *cbs_vp_buffer = - (Torus *)wop_pbs_buffer + - (ptrdiff_t)( - 1 + ((glwe_dimension + 1) * polynomial_size) + (polynomial_size + 1) + - (polynomial_size + 1) + (lwe_dimension + 1) + (polynomial_size + 1) + - (lwe_dimension + 1) * number_of_bits_of_message_including_padding); + int8_t *cbs_vp_buffer = + (int8_t *)wop_pbs_buffer + + (ptrdiff_t)get_buffer_size_wop_pbs( + glwe_dimension, lwe_dimension, polynomial_size, level_count_cbs, + number_of_bits_of_message_including_padding, + number_of_bits_to_extract, number_of_inputs); host_circuit_bootstrap_vertical_packing( v_stream, gpu_index, lwe_array_out, lwe_array_out_bit_extract, lut_vector, fourier_bsk, cbs_fpksk, cbs_vp_buffer, cbs_delta_log, glwe_dimension,