From 739db73d468e9da2fd85deefdeca174f1a6a66a6 Mon Sep 17 00:00:00 2001
From: Pedro Alves
Date: Thu, 24 Nov 2022 13:31:45 +0100
Subject: [PATCH] feat(cuda): batch_fft_ggsw_vector uses global memory in case
 there is not enough space in the shared memory

---
 src/bootstrap_wop.cuh | 12 ++++++------
 src/crypto/ggsw.cuh   | 41 ++++++++++++++++++++++++++++-------------
 2 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/src/bootstrap_wop.cuh b/src/bootstrap_wop.cuh
index e3c791592..17d0ecfec 100644
--- a/src/bootstrap_wop.cuh
+++ b/src/bootstrap_wop.cuh
@@ -275,9 +275,9 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
   double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
       r * ggsw_size * sizeof(double), *stream, gpu_index);
 
-  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
-                                               r, glwe_dimension,
-                                               polynomial_size, level_count);
+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size,
+      level_count, gpu_index, max_shared_memory);
 
   //////////////////////
 
@@ -653,9 +653,9 @@ void host_blind_rotate_and_sample_extraction(
   double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
       mbr_size * ggsw_size * sizeof(double), *stream, gpu_index);
 
-  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
-                                               mbr_size, glwe_dimension,
-                                               polynomial_size, l_gadget);
+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      v_stream, d_ggsw_fft_in, ggsw_in, mbr_size, glwe_dimension,
+      polynomial_size, l_gadget, gpu_index, max_shared_memory);
   checkCudaErrors(cudaGetLastError());
 
   //
diff --git a/src/crypto/ggsw.cuh b/src/crypto/ggsw.cuh
index 63a9ace05..4f5e70825 100644
--- a/src/crypto/ggsw.cuh
+++ b/src/crypto/ggsw.cuh
@@ -1,12 +1,17 @@
 #ifndef CONCRETE_CORE_GGSW_CUH
 #define CONCRETE_CORE_GGSW_CUH
 
-template <typename T, typename ST, class params>
-__global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
+template <typename T, typename ST, class params, sharedMemDegree SMD>
+__global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
+                                             char *device_mem) {
 
   extern __shared__ char sharedmem[];
+  double2 *selected_memory;
 
-  double2 *shared_output = (double2 *)sharedmem;
+  if constexpr (SMD == FULLSM)
+    selected_memory = (double2 *)sharedmem;
+  else
+    selected_memory = (double2 *)&device_mem[blockIdx.x * params::degree * sizeof(double)];
 
   // Compression
   int offset = blockIdx.x * blockDim.x;
@@ -17,24 +22,24 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
   for (int i = 0; i < log_2_opt; i++) {
     ST x = src[(2 * tid) + params::opt * offset];
     ST y = src[(2 * tid + 1) + params::opt * offset];
-    shared_output[tid].x = x / (double)std::numeric_limits<T>::max();
-    shared_output[tid].y = y / (double)std::numeric_limits<T>::max();
+    selected_memory[tid].x = x / (double)std::numeric_limits<T>::max();
+    selected_memory[tid].y = y / (double)std::numeric_limits<T>::max();
     tid += params::degree / params::opt;
   }
   synchronize_threads_in_block();
 
   // Switch to the FFT space
-  NSMFFT_direct<HalfDegree<params>>(shared_output);
+  NSMFFT_direct<HalfDegree<params>>(selected_memory);
   synchronize_threads_in_block();
 
-  correction_direct_fft_inplace(shared_output);
+  correction_direct_fft_inplace(selected_memory);
   synchronize_threads_in_block();
 
   // Write the output to global memory
   tid = threadIdx.x;
 #pragma unroll
   for (int j = 0; j < log_2_opt; j++) {
-    dest[tid + (params::opt >> 1) * offset] = shared_output[tid];
+    dest[tid + (params::opt >> 1) * offset] = selected_memory[tid];
     tid += params::degree / params::opt;
   }
 }
@@ -46,19 +51,29 @@
 template <typename T, typename ST, class params>
 void batch_fft_ggsw_vector(void *v_stream, double2 *dest, T *src, uint32_t r,
                            uint32_t glwe_dim, uint32_t polynomial_size,
-                           uint32_t level_count) {
+                           uint32_t level_count, uint32_t gpu_index,
+                           uint32_t max_shared_memory) {
   auto stream = static_cast<cudaStream_t *>(v_stream);
 
   int shared_memory_size = sizeof(double) * polynomial_size;
 
   int gridSize = r * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
-  ;
   int blockSize = polynomial_size / params::opt;
 
-  batch_fft_ggsw_vectors<T, ST, params>
-      <<<gridSize, blockSize, shared_memory_size, *stream>>>(dest, src);
-  checkCudaErrors(cudaGetLastError());
+  char *d_mem = nullptr;
+  if (max_shared_memory < shared_memory_size) {
+    d_mem = (char *)cuda_malloc_async(shared_memory_size * gridSize, *stream, gpu_index);
+    device_batch_fft_ggsw_vector<T, ST, params, NOSM>
+        <<<gridSize, blockSize, 0, *stream>>>(dest, src, d_mem);
+    checkCudaErrors(cudaGetLastError());
+    cuda_drop_async(d_mem, *stream, gpu_index);
+  } else {
+    device_batch_fft_ggsw_vector<T, ST, params, FULLSM>
+        <<<gridSize, blockSize, shared_memory_size, *stream>>>(dest, src,
+                                                               d_mem);
+    checkCudaErrors(cudaGetLastError());
+  }
 }
 
 #endif // CONCRETE_CORE_GGSW_CUH
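
Note (not part of the patch): the new max_shared_memory argument is the per-block dynamic shared memory budget that batch_fft_ggsw_vector compares against shared_memory_size = sizeof(double) * polynomial_size. For example, polynomial_size = 8192 needs 64 KB of scratch, which exceeds the default 48 KB per-block limit on most GPUs, so the NOSM path with the per-block global-memory buffer is taken. The patch does not show where callers obtain max_shared_memory; a minimal sketch using the standard CUDA runtime query, with a hypothetical helper name, could look like this:

#include <cuda_runtime.h>
#include <cstdint>

// Sketch only: get_max_shared_memory() is a hypothetical helper, not part of
// the patched sources; cudaDeviceGetAttribute() is the standard runtime call.
uint32_t get_max_shared_memory(uint32_t gpu_index) {
  int max_shared_memory = 0;
  // Maximum dynamic shared memory a single block may use, in bytes.
  cudaDeviceGetAttribute(&max_shared_memory,
                         cudaDevAttrMaxSharedMemoryPerBlock, gpu_index);
  return (uint32_t)max_shared_memory;
}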