From 1a76cadaa8038434a0b73d094e976fc0fd6e445d Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Sat, 24 Sep 2022 12:10:37 -0300 Subject: [PATCH] feat(cuda): Implement Stream-Ordered Memory Allocator for CUDA's CMUX Tree --- src/bootstrap_wop.cuh | 35 ++++++++++++++++++++++++++++++----- src/crypto/ggsw.cuh | 17 +++++++++++------ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/bootstrap_wop.cuh b/src/bootstrap_wop.cuh index 11fbe31c2..a1aff02f8 100644 --- a/src/bootstrap_wop.cuh +++ b/src/bootstrap_wop.cuh @@ -321,17 +321,26 @@ void host_cmux_tree( // std::cout << "Applying the FFT on m^tree" << std::endl; double2 *d_ggsw_fft_in; int ggsw_size = r * polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) * l_gadget; + + #if (CUDART_VERSION < 11020) checkCudaErrors(cudaMalloc((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double))); + #else + checkCudaErrors(cudaMallocAsync((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double), *stream)); + #endif batch_fft_ggsw_vector( - d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size, l_gadget); + v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size, l_gadget); ////////////////////// // Allocate global memory in case parameters are too large char *d_mem; if (max_shared_memory < memory_needed_per_block) { + #if (CUDART_VERSION < 11020) checkCudaErrors(cudaMalloc((void **) &d_mem, memory_needed_per_block * (1 << (r - 1)))); + #else + checkCudaErrors(cudaMallocAsync((void **) &d_mem, memory_needed_per_block * (1 << (r - 1)), *stream)); + #endif }else{ checkCudaErrors(cudaFuncSetAttribute( device_batch_cmux, @@ -345,8 +354,14 @@ void host_cmux_tree( // Allocate buffers int glwe_size = (glwe_dimension + 1) * polynomial_size; Torus *d_buffer1, *d_buffer2; + + #if (CUDART_VERSION < 11020) checkCudaErrors(cudaMalloc((void **)&d_buffer1, num_lut * glwe_size * sizeof(Torus))); checkCudaErrors(cudaMalloc((void **)&d_buffer2, num_lut * glwe_size * sizeof(Torus))); + #else + 
checkCudaErrors(cudaMallocAsync((void **)&d_buffer1, num_lut * glwe_size * sizeof(Torus), *stream)); + checkCudaErrors(cudaMallocAsync((void **)&d_buffer2, num_lut * glwe_size * sizeof(Torus), *stream)); + #endif checkCudaErrors(cudaMemcpyAsync( d_buffer1, lut_vector, num_lut * glwe_size * sizeof(Torus), @@ -383,20 +398,30 @@ void host_cmux_tree( } - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpy( + checkCudaErrors(cudaMemcpyAsync( glwe_out, output, (glwe_dimension+1) * polynomial_size * sizeof(Torus), - cudaMemcpyDeviceToDevice)); + cudaMemcpyDeviceToDevice, *stream)); - checkCudaErrors(cudaDeviceSynchronize()); + // We only need synchronization to assert that data is in glwe_out before + // returning. Memory release can be added to the stream and processed + // later. + checkCudaErrors(cudaStreamSynchronize(*stream)); // Free memory + #if (CUDART_VERSION < 11020) checkCudaErrors(cudaFree(d_ggsw_fft_in)); checkCudaErrors(cudaFree(d_buffer1)); checkCudaErrors(cudaFree(d_buffer2)); if(max_shared_memory < memory_needed_per_block) checkCudaErrors(cudaFree(d_mem)); + #else + checkCudaErrors(cudaFreeAsync(d_ggsw_fft_in, *stream)); + checkCudaErrors(cudaFreeAsync(d_buffer1, *stream)); + checkCudaErrors(cudaFreeAsync(d_buffer2, *stream)); + if(max_shared_memory < memory_needed_per_block) + checkCudaErrors(cudaFreeAsync(d_mem, *stream)); + #endif } diff --git a/src/crypto/ggsw.cuh b/src/crypto/ggsw.cuh index ec39738f1..ca7305bfc 100644 --- a/src/crypto/ggsw.cuh +++ b/src/crypto/ggsw.cuh @@ -40,11 +40,15 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src){ * Applies the FFT transform on sequence of GGSW ciphertexts already in the global memory */ template <typename T, class params> -void batch_fft_ggsw_vector(double2 *dest, T *src, - uint32_t r, - uint32_t glwe_dim, - uint32_t polynomial_size, - uint32_t l_gadget) { +void batch_fft_ggsw_vector( + void *v_stream, + double2 *dest, T *src, + uint32_t r, + uint32_t glwe_dim, + uint32_t polynomial_size, + 
uint32_t l_gadget) { + + auto stream = static_cast<cudaStream_t *>(v_stream); int shared_memory_size = sizeof(double) * polynomial_size; @@ -52,7 +56,8 @@ void batch_fft_ggsw_vector(double2 *dest, T *src, int gridSize = total_polynomials; int blockSize = polynomial_size / params::opt; - batch_fft_ggsw_vectors<T, params><<<gridSize, blockSize, shared_memory_size>>>(dest, src); + batch_fft_ggsw_vectors<T, params><<<gridSize, blockSize, shared_memory_size, *stream>>>(dest, + src); checkCudaErrors(cudaGetLastError()); }