feat(cuda): Refactor the amortized PBS to use asynchronous allocation.
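This change replaces the blocking cudaMalloc/cudaFree pairs in host_bootstrap_amortized with the stream-ordered helpers cuda_malloc_async and cuda_drop_async. For context, here is a minimal sketch of what such helpers could look like, assuming they wrap the CUDA 11.2+ stream-ordered allocator; the signatures follow the call sites in the diff, but the bodies (and the assumption that device.h declares them) are illustrative, not the repository's actual implementation:

#include <cuda_runtime.h>
#include <cstdint>

// Sketch only: allocate `size` bytes in stream order. `v_stream` is a
// type-erased cudaStream_t *, matching how the diff passes streams around.
void *cuda_malloc_async(uint64_t size, void *v_stream) {
  auto stream = static_cast<cudaStream_t *>(v_stream);
  void *ptr = nullptr;
  cudaMallocAsync(&ptr, size, *stream);
  return ptr;
}

// Sketch only: queue the free on the stream; it executes only after all
// previously enqueued work on that stream has finished.
void cuda_drop_async(void *ptr, void *v_stream) {
  auto stream = static_cast<cudaStream_t *>(v_stream);
  cudaFreeAsync(ptr, *stream);
}

Unlike cudaMalloc, a stream-ordered allocation does not synchronize the whole device, so the host thread can keep enqueueing work while the allocation is serviced.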
@@ -10,6 +10,7 @@
 #include "cooperative_groups.h"
 
 #include "../include/helper_cuda.h"
+#include "device.h"
 #include "bootstrap.h"
 #include "complex/operations.cuh"
 #include "crypto/gadget.cuh"
@@ -337,8 +338,7 @@ __host__ void host_bootstrap_amortized(
   // from one of three templates (no use, partial use or full use
   // of shared memory)
   if (max_shared_memory < SM_PART) {
-    checkCudaErrors(
-        cudaMalloc((void **)&d_mem, DM_FULL * input_lwe_ciphertext_count));
+    d_mem = (char*) cuda_malloc_async(DM_FULL * input_lwe_ciphertext_count, v_stream);
     device_bootstrap_amortized<Torus, params, NOSM><<<grid, thds, 0, *stream>>>(
         lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
         bootstrapping_key, d_mem, input_lwe_dimension, polynomial_size,
@@ -348,8 +348,7 @@ __host__ void host_bootstrap_amortized(
         cudaFuncAttributeMaxDynamicSharedMemorySize, SM_PART);
     cudaFuncSetCacheConfig(device_bootstrap_amortized<Torus, params, PARTIALSM>,
                            cudaFuncCachePreferShared);
-    checkCudaErrors(
-        cudaMalloc((void **)&d_mem, DM_PART * input_lwe_ciphertext_count));
+    d_mem = (char*) cuda_malloc_async(DM_PART * input_lwe_ciphertext_count, v_stream);
     device_bootstrap_amortized<Torus, params, PARTIALSM>
         <<<grid, thds, SM_PART, *stream>>>(
             lwe_array_out, lut_vector, lut_vector_indexes, lwe_array_in,
@@ -367,7 +366,7 @@ __host__ void host_bootstrap_amortized(
     checkCudaErrors(cudaFuncSetCacheConfig(
         device_bootstrap_amortized<Torus, params, FULLSM>,
         cudaFuncCachePreferShared));
-    checkCudaErrors(cudaMalloc((void **)&d_mem, 0));
+    d_mem = (char*) cuda_malloc_async(0, v_stream);
 
     device_bootstrap_amortized<Torus, params, FULLSM>
         <<<grid, thds, SM_FULL, *stream>>>(
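A note on the zero-byte allocation in this hunk, as read from the surrounding code: in the FULLSM branch all working memory fits in shared memory, so no global scratch is needed; the zero-size call appears to exist only so that d_mem stays initialized and the same kernel signature and final drop work across all three branches.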
@@ -378,7 +377,7 @@ __host__ void host_bootstrap_amortized(
   // Synchronize the streams before copying the result to lwe_array_out at the
   // right place
   cudaStreamSynchronize(*stream);
-  cudaFree(d_mem);
+  cuda_drop_async(d_mem, v_stream);
 }
 
 template <typename Torus, class params>
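As a standalone illustration of the pattern this refactor adopts (toy kernel and sizes, not library code): scratch memory is allocated and released in stream order, so the host only blocks at explicit synchronization points, mirroring the cudaStreamSynchronize followed by cuda_drop_async at the end of the diff.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void use_scratch(char *scratch, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) scratch[i] = (char)i;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int n = 1 << 20;
  char *d_mem = nullptr;
  // Stream-ordered allocation: queued on the stream, no device-wide
  // synchronization as with plain cudaMalloc.
  cudaMallocAsync((void **)&d_mem, n, stream);
  use_scratch<<<(n + 255) / 256, 256, 0, stream>>>(d_mem, n);
  // Mirror the diff: synchronize so the host can safely read results,
  // then queue the stream-ordered free.
  cudaStreamSynchronize(stream);
  cudaFreeAsync(d_mem, stream);
  cudaStreamSynchronize(stream); // let the queued free complete
  cudaStreamDestroy(stream);
  printf("done\n");
  return 0;
}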