From 0b58741fd457abfd08698ba54a7905ae374786c4 Mon Sep 17 00:00:00 2001 From: Pedro Alves Date: Mon, 7 Nov 2022 12:43:09 -0300 Subject: [PATCH] feat(cuda): Refactor the amortized PBS to use asynchronous allocation. --- src/bootstrap_amortized.cuh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/bootstrap_amortized.cuh b/src/bootstrap_amortized.cuh index 51d9ff9ff..f67f44d12 100644 --- a/src/bootstrap_amortized.cuh +++ b/src/bootstrap_amortized.cuh @@ -10,6 +10,7 @@ #include "cooperative_groups.h" #include "../include/helper_cuda.h" +#include "device.h" #include "bootstrap.h" #include "complex/operations.cuh" #include "crypto/gadget.cuh" @@ -337,8 +338,7 @@ __host__ void host_bootstrap_amortized( // from one of three templates (no use, partial use or full use // of shared memory) if (max_shared_memory < SM_PART) { - checkCudaErrors( - cudaMalloc((void **)&d_mem, DM_FULL * input_lwe_ciphertext_count)); + d_mem = (char*) cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count, v_stream); device_bootstrap_amortized<<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size, @@ -348,8 +348,7 @@ __host__ void host_bootstrap_amortized( cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART); cudaFuncSetCacheConfig(device_bootstrap_amortized, cudaFuncCachePreferShared); - checkCudaErrors( - cudaMalloc((void **)&d_mem, DM_PART * input_lwe_ciphertext_count)); + d_mem = (char*) cuda_malloc_async(DM_PART * input_lwe_ciphertext_count, v_stream); device_bootstrap_amortized <<>>( lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in, @@ -367,7 +366,7 @@ __host__ void host_bootstrap_amortized( checkCudaErrors(cudaFuncSetCacheConfig( device_bootstrap_amortized, cudaFuncCachePreferShared)); - checkCudaErrors(cudaMalloc((void **)&d_mem, 0)); + d_mem = (char*) cuda_malloc_async(0, v_stream); device_bootstrap_amortized <<>>( @@ -378,7 +377,7 @@ __host__ void host_bootstrap_amortized( // Synchronize the streams before copying the result to lwe_array_out at the // right place cudaStreamSynchronize(*stream); - cudaFree(d_mem); + cuda_drop_async(d_mem, v_stream); } template