mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(cuda): Implement Stream-Ordered Memory Allocator for CUDA's CMUX Tree
This commit is contained in:
@@ -321,17 +321,26 @@ void host_cmux_tree(
|
||||
// std::cout << "Applying the FFT on m^tree" << std::endl;
|
||||
double2 *d_ggsw_fft_in;
|
||||
int ggsw_size = r * polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) * l_gadget;
|
||||
|
||||
#if (CUDART_VERSION < 11020)
|
||||
checkCudaErrors(cudaMalloc((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double)));
|
||||
#else
|
||||
checkCudaErrors(cudaMallocAsync((void **)&d_ggsw_fft_in, ggsw_size * sizeof(double), *stream));
|
||||
#endif
|
||||
|
||||
batch_fft_ggsw_vector<Torus, STorus, params>(
|
||||
d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size, l_gadget);
|
||||
v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size, l_gadget);
|
||||
|
||||
//////////////////////
|
||||
|
||||
// Allocate global memory in case parameters are too large
|
||||
char *d_mem;
|
||||
if (max_shared_memory < memory_needed_per_block) {
|
||||
#if (CUDART_VERSION < 11020)
|
||||
checkCudaErrors(cudaMalloc((void **) &d_mem, memory_needed_per_block * (1 << (r - 1))));
|
||||
#else
|
||||
checkCudaErrors(cudaMallocAsync((void **) &d_mem, memory_needed_per_block * (1 << (r - 1)), *stream));
|
||||
#endif
|
||||
}else{
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
device_batch_cmux<Torus, STorus, params, FULLSM>,
|
||||
@@ -345,8 +354,14 @@ void host_cmux_tree(
|
||||
// Allocate buffers
|
||||
int glwe_size = (glwe_dimension + 1) * polynomial_size;
|
||||
Torus *d_buffer1, *d_buffer2;
|
||||
|
||||
#if (CUDART_VERSION < 11020)
|
||||
checkCudaErrors(cudaMalloc((void **)&d_buffer1, num_lut * glwe_size * sizeof(Torus)));
|
||||
checkCudaErrors(cudaMalloc((void **)&d_buffer2, num_lut * glwe_size * sizeof(Torus)));
|
||||
#else
|
||||
checkCudaErrors(cudaMallocAsync((void **)&d_buffer1, num_lut * glwe_size * sizeof(Torus), *stream));
|
||||
checkCudaErrors(cudaMallocAsync((void **)&d_buffer2, num_lut * glwe_size * sizeof(Torus), *stream));
|
||||
#endif
|
||||
checkCudaErrors(cudaMemcpyAsync(
|
||||
d_buffer1, lut_vector,
|
||||
num_lut * glwe_size * sizeof(Torus),
|
||||
@@ -383,20 +398,30 @@ void host_cmux_tree(
|
||||
|
||||
}
|
||||
|
||||
checkCudaErrors(cudaStreamSynchronize(*stream));
|
||||
checkCudaErrors(cudaMemcpy(
|
||||
checkCudaErrors(cudaMemcpyAsync(
|
||||
glwe_out, output,
|
||||
(glwe_dimension+1) * polynomial_size * sizeof(Torus),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
cudaMemcpyDeviceToDevice, *stream));
|
||||
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
// We only need synchronization to assert that data is in glwe_out before
|
||||
// returning. Memory release can be added to the stream and processed
|
||||
// later.
|
||||
checkCudaErrors(cudaStreamSynchronize(*stream));
|
||||
|
||||
// Free memory
|
||||
#if (CUDART_VERSION < 11020)
|
||||
checkCudaErrors(cudaFree(d_ggsw_fft_in));
|
||||
checkCudaErrors(cudaFree(d_buffer1));
|
||||
checkCudaErrors(cudaFree(d_buffer2));
|
||||
if(max_shared_memory < memory_needed_per_block)
|
||||
checkCudaErrors(cudaFree(d_mem));
|
||||
#else
|
||||
checkCudaErrors(cudaFreeAsync(d_ggsw_fft_in, *stream));
|
||||
checkCudaErrors(cudaFreeAsync(d_buffer1, *stream));
|
||||
checkCudaErrors(cudaFreeAsync(d_buffer2, *stream));
|
||||
if(max_shared_memory < memory_needed_per_block)
|
||||
checkCudaErrors(cudaFreeAsync(d_mem, *stream));
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -40,11 +40,15 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src){
|
||||
* Applies the FFT transform on sequence of GGSW ciphertexts already in the global memory
|
||||
*/
|
||||
template <typename T, typename ST, class params>
|
||||
void batch_fft_ggsw_vector(double2 *dest, T *src,
|
||||
uint32_t r,
|
||||
uint32_t glwe_dim,
|
||||
uint32_t polynomial_size,
|
||||
uint32_t l_gadget) {
|
||||
void batch_fft_ggsw_vector(
|
||||
void *v_stream,
|
||||
double2 *dest, T *src,
|
||||
uint32_t r,
|
||||
uint32_t glwe_dim,
|
||||
uint32_t polynomial_size,
|
||||
uint32_t l_gadget) {
|
||||
|
||||
auto stream = static_cast<cudaStream_t *>(v_stream);
|
||||
|
||||
int shared_memory_size = sizeof(double) * polynomial_size;
|
||||
|
||||
@@ -52,7 +56,8 @@ void batch_fft_ggsw_vector(double2 *dest, T *src,
|
||||
int gridSize = total_polynomials;
|
||||
int blockSize = polynomial_size / params::opt;
|
||||
|
||||
batch_fft_ggsw_vectors<T, ST, params><<<gridSize, blockSize, shared_memory_size>>>(dest, src);
|
||||
batch_fft_ggsw_vectors<T, ST, params><<<gridSize, blockSize, shared_memory_size, *stream>>>(dest,
|
||||
src);
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user