diff --git a/src/crypto/bootstrapping_key.cuh b/src/crypto/bootstrapping_key.cuh
index 540346ac2..21b00141f 100644
--- a/src/crypto/bootstrapping_key.cuh
+++ b/src/crypto/bootstrapping_key.cuh
@@ -118,6 +118,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
   case 512:
     if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
       buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
+      checkCudaErrors(cudaFuncSetAttribute(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<512>, ForwardFFT>, FULLSM>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
+      checkCudaErrors(cudaFuncSetCacheConfig(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<512>, ForwardFFT>, FULLSM>,
+          cudaFuncCachePreferShared));
       batch_NSMFFT<FFTDegree<AmortizedDegree<512>, ForwardFFT>, FULLSM>
           <<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
                                                                  buffer);
@@ -131,6 +137,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
   case 1024:
     if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
       buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
+      checkCudaErrors(cudaFuncSetAttribute(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<1024>, ForwardFFT>, FULLSM>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
+      checkCudaErrors(cudaFuncSetCacheConfig(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<1024>, ForwardFFT>, FULLSM>,
+          cudaFuncCachePreferShared));
       batch_NSMFFT<FFTDegree<AmortizedDegree<1024>, ForwardFFT>, FULLSM>
           <<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
                                                                  buffer);
@@ -144,6 +156,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
   case 2048:
     if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
       buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
+      checkCudaErrors(cudaFuncSetAttribute(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<2048>, ForwardFFT>, FULLSM>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
+      checkCudaErrors(cudaFuncSetCacheConfig(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<2048>, ForwardFFT>, FULLSM>,
+          cudaFuncCachePreferShared));
       batch_NSMFFT<FFTDegree<AmortizedDegree<2048>, ForwardFFT>, FULLSM>
           <<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
                                                                  buffer);
@@ -157,6 +175,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
   case 4096:
     if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
       buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
+      checkCudaErrors(cudaFuncSetAttribute(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<4096>, ForwardFFT>, FULLSM>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
+      checkCudaErrors(cudaFuncSetCacheConfig(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<4096>, ForwardFFT>, FULLSM>,
+          cudaFuncCachePreferShared));
       batch_NSMFFT<FFTDegree<AmortizedDegree<4096>, ForwardFFT>, FULLSM>
           <<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
                                                                  buffer);
@@ -170,6 +194,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
   case 8192:
     if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
       buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
+      checkCudaErrors(cudaFuncSetAttribute(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<8192>, ForwardFFT>, FULLSM>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
+      checkCudaErrors(cudaFuncSetCacheConfig(
+          batch_NSMFFT<FFTDegree<AmortizedDegree<8192>, ForwardFFT>, FULLSM>,
+          cudaFuncCachePreferShared));
       batch_NSMFFT<FFTDegree<AmortizedDegree<8192>, ForwardFFT>, FULLSM>
           <<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
                                                                  buffer);
diff --git a/src/fft/bnsmfft.cuh b/src/fft/bnsmfft.cuh
index d712902db..dd41ee44e 100644
--- a/src/fft/bnsmfft.cuh
+++ b/src/fft/bnsmfft.cuh
@@ -46,6 +46,7 @@ template <class params> __device__ double2 negacyclic_twiddle(int j) {
     break;
   case 8192:
     twid = negTwids13[j];
+    break;
   default:
     twid.x = 0;
     twid.y = 0;
@@ -721,12 +722,9 @@ __device__ void correction_inverse_fft_inplace(double2 *x) {
 template <class params, sharedMemDegree SMD>
 __global__ void batch_NSMFFT(double2 *d_input, double2 *d_output,
                              double2 *buffer) {
-  double2 *fft = &buffer[blockIdx.x * params::degree / 2];
-  if constexpr (SMD != NOSM) {
-    extern __shared__ double2 sharedMemoryFFT[];
-    fft = sharedMemoryFFT;
-  }
-
+  extern __shared__ double2 sharedMemoryFFT[];
+  double2 *fft = (SMD == NOSM) ? &buffer[blockIdx.x * params::degree / 2]
+                               : sharedMemoryFFT;
   int tid = threadIdx.x;
 #pragma unroll