feat(cuda): Implement Stream-Ordered Memory Allocator for CUDA's CMUX Tree

This commit is contained in:
Pedro Alves
2022-09-24 12:10:37 -03:00
committed by Agnès Leroy
parent 3d6524ccf3
commit 1a76cadaa8
2 changed files with 41 additions and 11 deletions

View File

@@ -321,17 +321,26 @@ void host_cmux_tree(
// std::cout << "Applying the FFT on m^tree" << std::endl;
double2 *d_ggsw_fft_in;
int ggsw_size = r * polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) * l_gadget;
#if (CUDART_VERSION < 11020)
checkCudaErrors(cudaMalloc((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double)));
#else
checkCudaErrors(cudaMallocAsync((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double), *stream));
#endif
batch_fft_ggsw_vector<Torus, STorus, params>(
d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size, l_gadget);
v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size, l_gadget);
//////////////////////
// Allocate global memory in case parameters are too large
char *d_mem;
if (max_shared_memory < memory_needed_per_block) {
#if (CUDART_VERSION < 11020)
checkCudaErrors(cudaMalloc((void **) &d_mem, memory_needed_per_block * (1 << (r - 1))));
#else
checkCudaErrors(cudaMallocAsync((void **) &d_mem, memory_needed_per_block * (1 << (r - 1)), *stream));
#endif
}else{
checkCudaErrors(cudaFuncSetAttribute(
device_batch_cmux<Torus, STorus, params, FULLSM>,
@@ -345,8 +354,14 @@ void host_cmux_tree(
// Allocate buffers
int glwe_size = (glwe_dimension + 1) * polynomial_size;
Torus *d_buffer1, *d_buffer2;
#if (CUDART_VERSION < 11020)
checkCudaErrors(cudaMalloc((void **)&d_buffer1, num_lut * glwe_size * sizeof(Torus)));
checkCudaErrors(cudaMalloc((void **)&d_buffer2, num_lut * glwe_size * sizeof(Torus)));
#else
checkCudaErrors(cudaMallocAsync((void **)&d_buffer1, num_lut * glwe_size * sizeof(Torus), *stream));
checkCudaErrors(cudaMallocAsync((void **)&d_buffer2, num_lut * glwe_size * sizeof(Torus), *stream));
#endif
checkCudaErrors(cudaMemcpyAsync(
d_buffer1, lut_vector,
num_lut * glwe_size * sizeof(Torus),
@@ -383,20 +398,30 @@ void host_cmux_tree(
}
checkCudaErrors(cudaStreamSynchronize(*stream));
checkCudaErrors(cudaMemcpy(
checkCudaErrors(cudaMemcpyAsync(
glwe_out, output,
(glwe_dimension+1) * polynomial_size * sizeof(Torus),
cudaMemcpyDeviceToDevice));
cudaMemcpyDeviceToDevice, *stream));
checkCudaErrors(cudaDeviceSynchronize());
// We only need to synchronize to ensure that the data is in glwe_out before
// returning. Memory release can be enqueued on the stream and processed
// later.
checkCudaErrors(cudaStreamSynchronize(*stream));
// Free memory
#if (CUDART_VERSION < 11020)
checkCudaErrors(cudaFree(d_ggsw_fft_in));
checkCudaErrors(cudaFree(d_buffer1));
checkCudaErrors(cudaFree(d_buffer2));
if(max_shared_memory < memory_needed_per_block)
checkCudaErrors(cudaFree(d_mem));
#else
checkCudaErrors(cudaFreeAsync(d_ggsw_fft_in, *stream));
checkCudaErrors(cudaFreeAsync(d_buffer1, *stream));
checkCudaErrors(cudaFreeAsync(d_buffer2, *stream));
if(max_shared_memory < memory_needed_per_block)
checkCudaErrors(cudaFreeAsync(d_mem, *stream));
#endif
}

View File

@@ -40,11 +40,15 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src){
* Applies the FFT to a sequence of GGSW ciphertexts already in global memory
*/
template <typename T, typename ST, class params>
void batch_fft_ggsw_vector(double2 *dest, T *src,
uint32_t r,
uint32_t glwe_dim,
uint32_t polynomial_size,
uint32_t l_gadget) {
void batch_fft_ggsw_vector(
void *v_stream,
double2 *dest, T *src,
uint32_t r,
uint32_t glwe_dim,
uint32_t polynomial_size,
uint32_t l_gadget) {
auto stream = static_cast<cudaStream_t *>(v_stream);
int shared_memory_size = sizeof(double) * polynomial_size;
@@ -52,7 +56,8 @@ void batch_fft_ggsw_vector(double2 *dest, T *src,
int gridSize = total_polynomials;
int blockSize = polynomial_size / params::opt;
batch_fft_ggsw_vectors<T, ST, params><<<gridSize, blockSize, shared_memory_size>>>(dest, src);
batch_fft_ggsw_vectors<T, ST, params><<<gridSize, blockSize, shared_memory_size, *stream>>>(dest,
src);
checkCudaErrors(cudaGetLastError());
}