mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(cuda): batch_fft_ggsw_vector uses global memory in case there is not enough space in the shared memory
This commit is contained in:
@@ -275,9 +275,9 @@ void host_cmux_tree(void *v_stream, Torus *glwe_array_out, Torus *ggsw_in,
   double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
       r * ggsw_size * sizeof(double), *stream, gpu_index);

-  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
-                                               r, glwe_dimension,
-                                               polynomial_size, level_count);
+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      v_stream, d_ggsw_fft_in, ggsw_in, r, glwe_dimension, polynomial_size,
+      level_count, gpu_index, max_shared_memory);

   //////////////////////

@@ -653,9 +653,9 @@ void host_blind_rotate_and_sample_extraction(
   double2 *d_ggsw_fft_in = (double2 *)cuda_malloc_async(
       mbr_size * ggsw_size * sizeof(double), *stream, gpu_index);

-  batch_fft_ggsw_vector<Torus, STorus, params>(v_stream, d_ggsw_fft_in, ggsw_in,
-                                               mbr_size, glwe_dimension,
-                                               polynomial_size, l_gadget);
+  batch_fft_ggsw_vector<Torus, STorus, params>(
+      v_stream, d_ggsw_fft_in, ggsw_in, mbr_size, glwe_dimension,
+      polynomial_size, l_gadget, gpu_index, max_shared_memory);
   checkCudaErrors(cudaGetLastError());

   //

@@ -1,12 +1,17 @@
 #ifndef CONCRETE_CORE_GGSW_CUH
 #define CONCRETE_CORE_GGSW_CUH

-template <typename T, typename ST, class params>
-__global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
+template <typename T, typename ST, class params, sharedMemDegree SMD>
+__global__ void device_batch_fft_ggsw_vector(double2 *dest, T *src,
+                                             char *device_mem) {

   extern __shared__ char sharedmem[];
+  double2 *selected_memory;

-  double2 *shared_output = (double2 *)sharedmem;
+  if constexpr (SMD == FULLSM)
+    selected_memory = (double2 *)sharedmem;
+  else
+    selected_memory = (double2 *)device_mem[blockIdx.x * params::degree];

   // Compression
   int offset = blockIdx.x * blockDim.x;
@@ -17,24 +22,24 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
   for (int i = 0; i < log_2_opt; i++) {
     ST x = src[(2 * tid) + params::opt * offset];
     ST y = src[(2 * tid + 1) + params::opt * offset];
-    shared_output[tid].x = x / (double)std::numeric_limits<T>::max();
-    shared_output[tid].y = y / (double)std::numeric_limits<T>::max();
+    selected_memory[tid].x = x / (double)std::numeric_limits<T>::max();
+    selected_memory[tid].y = y / (double)std::numeric_limits<T>::max();
     tid += params::degree / params::opt;
   }
   synchronize_threads_in_block();

   // Switch to the FFT space
-  NSMFFT_direct<HalfDegree<params>>(shared_output);
+  NSMFFT_direct<HalfDegree<params>>(selected_memory);
   synchronize_threads_in_block();

-  correction_direct_fft_inplace<params>(shared_output);
+  correction_direct_fft_inplace<params>(selected_memory);
   synchronize_threads_in_block();

   // Write the output to global memory
   tid = threadIdx.x;
 #pragma unroll
   for (int j = 0; j < log_2_opt; j++) {
-    dest[tid + (params::opt >> 1) * offset] = shared_output[tid];
+    dest[tid + (params::opt >> 1) * offset] = selected_memory[tid];
     tid += params::degree / params::opt;
   }
 }
@@ -46,19 +51,29 @@ __global__ void batch_fft_ggsw_vectors(double2 *dest, T *src) {
|
||||
/*
 * Host wrapper: forward-FFT a batch of GGSW ciphertexts.
 *
 * Launches one block per GLWE polynomial, i.e.
 *   r * (glwe_dim + 1)^2 * level_count blocks of
 *   polynomial_size / params::opt threads each.
 * Each block needs `polynomial_size * sizeof(double)` bytes of scratch.
 * When that fits in shared memory the FULLSM kernel variant is used;
 * otherwise scratch is taken from a stream-ordered global-memory buffer
 * (NOSM variant).
 *
 * - v_stream: cudaStream_t* passed as void*; all work is stream-ordered.
 * - dest: device pointer receiving the FFT'd GGSW batch.
 * - src:  device pointer to the standard-domain GGSW batch.
 */
template <typename T, typename ST, class params>
void batch_fft_ggsw_vector(void *v_stream, double2 *dest, T *src, uint32_t r,
                           uint32_t glwe_dim, uint32_t polynomial_size,
                           uint32_t level_count, uint32_t gpu_index,
                           uint32_t max_shared_memory) {

  auto stream = static_cast<cudaStream_t *>(v_stream);

  // Per-block scratch: one polynomial of double coefficients.
  uint64_t shared_memory_size = sizeof(double) * polynomial_size;

  int gridSize = r * (glwe_dim + 1) * (glwe_dim + 1) * level_count;
  int blockSize = polynomial_size / params::opt;

  if (max_shared_memory < shared_memory_size) {
    // Not enough shared memory: fall back to a global-memory scratch
    // buffer. Every block indexes its own slice of `d_mem` (the NOSM
    // kernel offsets by blockIdx.x), so the allocation must cover ALL
    // gridSize blocks — a single per-block-sized allocation would be
    // overrun by every block past the first.
    char *d_mem = (char *)cuda_malloc_async(
        (uint64_t)gridSize * shared_memory_size, *stream, gpu_index);
    device_batch_fft_ggsw_vector<T, ST, params, NOSM>
        <<<gridSize, blockSize, 0, *stream>>>(dest, src, d_mem);
    checkCudaErrors(cudaGetLastError());
    cuda_drop_async(d_mem, *stream, gpu_index);
  } else {
    // Shared-memory path: the device-memory argument is unused by the
    // FULLSM variant, so pass a well-defined nullptr instead of an
    // uninitialized pointer.
    device_batch_fft_ggsw_vector<T, ST, params, FULLSM>
        <<<gridSize, blockSize, shared_memory_size, *stream>>>(dest, src,
                                                               nullptr);
    checkCudaErrors(cudaGetLastError());
  }
}
|
||||
|
||||
#endif // CONCRETE_CORE_GGSW_CUH
|
||||
|
||||
Reference in New Issue
Block a user