feat(cuda): batch_fft_ggsw_vector falls back to global memory when there is not enough shared memory

Pedro Alves
2022-11-24 13:31:45 +01:00
committed by Agnès Leroy
parent 56b986da8b
commit 739db73d46
2 changed files with 34 additions and 19 deletions
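The change below makes batch_fft_ggsw_vector check whether its FFT scratch (sizeof(double) * polynomial_size bytes of dynamic shared memory per block) fits into the GPU's per-block shared memory, and fall back to a global-memory buffer when it does not. The diff does not show where callers obtain the new max_shared_memory argument; one way it could be queried, shown here only as an illustrative sketch and not as code from this commit, is cudaDeviceGetAttribute:

// Illustrative sketch, not part of this commit: query the per-block shared
// memory limit that batch_fft_ggsw_vector compares against its scratch size
// (sizeof(double) * polynomial_size).
#include <cstdint>
#include <cuda_runtime.h>

inline uint32_t get_max_shared_memory(int gpu_index) {
  int max_shared_memory = 0;
  cudaDeviceGetAttribute(&max_shared_memory,
                         cudaDevAttrMaxSharedMemoryPerBlock, gpu_index);
  return (uint32_t)max_shared_memory; // in bytes, per block
}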


@@ -275,9 +275,9 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
   double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
       r * ggsw_size * sizeof(double), *stream, gpu_index);
-  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
-                                               r, glwe_dimension,
-                                               polynomial_size, level_count);
+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size,
+      level_count, gpu_index, max_shared_memory);
   //////////////////////
@@ -653,9 +653,9 @@ void host_blind_rotate_and_sample_extraction(
   double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
       mbr_size * ggsw_size * sizeof(double), *stream, gpu_index);
-  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
-                                               mbr_size, glwe_dimension,
-                                               polynomial_size, l_gadget);
+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      v_stream, d_ggsw_fft_in, ggsw_in, mbr_size, glwe_dimension,
+      polynomial_size, l_gadget, gpu_index, max_shared_memory);
   checkCudaErrors(cudaGetLastError());
   //


@@ -1,12 +1,17 @@
 #ifndef CONCRETE_CORE_GGSW_CUH
 #define CONCRETE_CORE_GGSW_CUH
 
-template <typename T, typename ST, class params>
-__global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
+template <typename T, typename ST, class params, sharedMemDegree SMD>
+__global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
+                                             char *device_mem) {
   extern __shared__ char sharedmem[];
-  double2 *shared_output = (double2 *)sharedmem;
+  double2 *selected_memory;
+
+  if constexpr (SMD == FULLSM)
+    selected_memory = (double2 *)sharedmem;
+  else
+    selected_memory =
+        (double2 *)&device_mem[blockIdx.x * params::degree * sizeof(double)];
 
   // Compression
   int offset = blockIdx.x * blockDim.x;
@@ -17,24 +22,24 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
   for (int i = 0; i < log_2_opt; i++) {
     ST x = src[(2 * tid) + params::opt * offset];
     ST y = src[(2 * tid + 1) + params::opt * offset];
-    shared_output[tid].x = x / (double)std::numeric_limits<T>::max();
-    shared_output[tid].y = y / (double)std::numeric_limits<T>::max();
+    selected_memory[tid].x = x / (double)std::numeric_limits<T>::max();
+    selected_memory[tid].y = y / (double)std::numeric_limits<T>::max();
     tid += params::degree / params::opt;
   }
   synchronize_threads_in_block();
 
   // Switch to the FFT space
-  NSMFFT_direct<HalfDegree<params>>(shared_output);
+  NSMFFT_direct<HalfDegree<params>>(selected_memory);
   synchronize_threads_in_block();
 
-  correction_direct_fft_inplace<params>(shared_output);
+  correction_direct_fft_inplace<params>(selected_memory);
   synchronize_threads_in_block();
 
   // Write the output to global memory
   tid = threadIdx.x;
 #pragma unroll
   for (int j = 0; j < log_2_opt; j++) {
-    dest[tid + (params::opt >> 1) * offset] = shared_output[tid];
+    dest[tid + (params::opt >> 1) * offset] = selected_memory[tid];
     tid += params::degree / params::opt;
   }
 }
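When the NOSM path is taken, every block writes its scratch to its own slice of device_mem, so the fallback buffer must be sized for the whole grid. A small sketch of the sizing arithmetic assumed here (illustrative helpers, not code from the diff): each block needs the same footprint it would have used in shared memory, polynomial_size * sizeof(double) bytes, and block b starts at b times that offset.

// Sketch only: sizing of the global-memory fallback scratch.
#include <cstddef>
#include <cstdint>

inline size_t scratch_bytes_per_block(uint32_t polynomial_size) {
  // Same footprint as the dynamic shared memory the FULLSM path requests.
  return (size_t)polynomial_size * sizeof(double);
}

inline size_t scratch_bytes_total(int gridSize, uint32_t polynomial_size) {
  // One slice per block; block b begins at b * scratch_bytes_per_block(...).
  return (size_t)gridSize * scratch_bytes_per_block(polynomial_size);
}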
@@ -46,19 +51,29 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
 template <typename T, typename ST, class params>
 void batch_fft_ggsw_vector(void *v_stream, double2 *dest, T *src, uint32_t r,
                            uint32_t glwe_dim, uint32_t polynomial_size,
-                           uint32_t level_count) {
+                           uint32_t level_count, uint32_t gpu_index,
+                           uint32_t max_shared_memory) {
   auto stream = static_cast<cudaStream_t *>(v_stream);
   int shared_memory_size = sizeof(double) * polynomial_size;
   int gridSize = r * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
   int blockSize = polynomial_size / params::opt;
-  batch_fft_ggsw_vectors<T, ST, params>
-      <<<gridSize, blockSize, shared_memory_size, *stream>>>(dest, src);
-  checkCudaErrors(cudaGetLastError());
+
+  char *d_mem = nullptr;
+  if (max_shared_memory < shared_memory_size) {
+    d_mem = (char *)cuda_malloc_async((size_t)gridSize * shared_memory_size,
+                                      *stream, gpu_index);
+    device_batch_fft_ggsw_vector<T, ST, params, NOSM>
+        <<<gridSize, blockSize, 0, *stream>>>(dest, src, d_mem);
+    checkCudaErrors(cudaGetLastError());
+    cuda_drop_async(d_mem, *stream, gpu_index);
+  } else {
+    device_batch_fft_ggsw_vector<T, ST, params, FULLSM>
+        <<<gridSize, blockSize, shared_memory_size, *stream>>>(dest, src,
+                                                               d_mem);
+    checkCudaErrors(cudaGetLastError());
+  }
 }
 #endif // CONCRETE_CORE_GGSW_CUH
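A hedged usage sketch of the updated host entry point follows; the wrapper name and the way the device limit is queried are assumptions added for illustration, and only the extended batch_fft_ggsw_vector signature comes from the diff above.

// Illustrative caller, not part of this commit. Assumes the header above is
// included; the wrapper and its arguments are placeholders.
template <typename Torus, typename STorus, class params>
void example_batch_fft(void *v_stream, double2 *d_ggsw_fft_in, Torus *ggsw_in,
                       uint32_t r, uint32_t glwe_dimension,
                       uint32_t polynomial_size, uint32_t level_count,
                       uint32_t gpu_index) {
  // Query the device limit once and let batch_fft_ggsw_vector pick between
  // the shared-memory (FULLSM) and global-memory (NOSM) kernels.
  int max_shared_memory = 0;
  cudaDeviceGetAttribute(&max_shared_memory,
                         cudaDevAttrMaxSharedMemoryPerBlock, gpu_index);
  batch_fft_ggsw_vector<Torus, STorus, params>(
      v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size,
      level_count, gpu_index, (uint32_t)max_shared_memory);
}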