mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(cuda): fix pbs for 8192 polynomial_size
This commit is contained in:
committed by
Agnès Leroy
parent
921c0a6306
commit
c1f1b533ea
@@ -118,6 +118,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
case 512:
|
||||
if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
|
||||
buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
batch_NSMFFT<FFTDegree<Degree<512>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
|
||||
checkCudaErrors(cudaFuncSetCacheConfig(
|
||||
batch_NSMFFT<FFTDegree<Degree<512>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncCachePreferShared));
|
||||
batch_NSMFFT<FFTDegree<Degree<512>, ForwardFFT>, FULLSM>
|
||||
<<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
|
||||
buffer);
|
||||
@@ -131,6 +137,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
case 1024:
|
||||
if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
|
||||
buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
batch_NSMFFT<FFTDegree<Degree<1024>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
|
||||
checkCudaErrors(cudaFuncSetCacheConfig(
|
||||
batch_NSMFFT<FFTDegree<Degree<1024>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncCachePreferShared));
|
||||
batch_NSMFFT<FFTDegree<Degree<1024>, ForwardFFT>, FULLSM>
|
||||
<<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
|
||||
buffer);
|
||||
@@ -144,6 +156,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
case 2048:
|
||||
if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
|
||||
buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
batch_NSMFFT<FFTDegree<Degree<2048>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
|
||||
checkCudaErrors(cudaFuncSetCacheConfig(
|
||||
batch_NSMFFT<FFTDegree<Degree<2048>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncCachePreferShared));
|
||||
batch_NSMFFT<FFTDegree<Degree<2048>, ForwardFFT>, FULLSM>
|
||||
<<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
|
||||
buffer);
|
||||
@@ -157,6 +175,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
case 4096:
|
||||
if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
|
||||
buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
batch_NSMFFT<FFTDegree<Degree<4096>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
|
||||
checkCudaErrors(cudaFuncSetCacheConfig(
|
||||
batch_NSMFFT<FFTDegree<Degree<4096>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncCachePreferShared));
|
||||
batch_NSMFFT<FFTDegree<Degree<4096>, ForwardFFT>, FULLSM>
|
||||
<<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
|
||||
buffer);
|
||||
@@ -170,6 +194,12 @@ void cuda_convert_lwe_bootstrap_key(double2 *dest, ST *src, void *v_stream,
|
||||
case 8192:
|
||||
if (shared_memory_size <= cuda_get_max_shared_memory(gpu_index)) {
|
||||
buffer = (double2 *)cuda_malloc_async(0, *stream, gpu_index);
|
||||
checkCudaErrors(cudaFuncSetAttribute(
|
||||
batch_NSMFFT<FFTDegree<Degree<8192>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size));
|
||||
checkCudaErrors(cudaFuncSetCacheConfig(
|
||||
batch_NSMFFT<FFTDegree<Degree<8192>, ForwardFFT>, FULLSM>,
|
||||
cudaFuncCachePreferShared));
|
||||
batch_NSMFFT<FFTDegree<Degree<8192>, ForwardFFT>, FULLSM>
|
||||
<<<gridSize, blockSize, shared_memory_size, *stream>>>(d_bsk, dest,
|
||||
buffer);
|
||||
|
||||
@@ -46,6 +46,7 @@ template <int degree> __device__ double2 negacyclic_twiddle(int j) {
|
||||
break;
|
||||
case 8192:
|
||||
twid = negTwids13[j];
|
||||
break;
|
||||
default:
|
||||
twid.x = 0;
|
||||
twid.y = 0;
|
||||
@@ -721,12 +722,9 @@ __device__ void correction_inverse_fft_inplace(double2 *x) {
|
||||
template <class params, sharedMemDegree SMD>
|
||||
__global__ void batch_NSMFFT(double2 *d_input, double2 *d_output,
|
||||
double2 *buffer) {
|
||||
double2 *fft = &buffer[blockIdx.x * params::degree / 2];
|
||||
if constexpr (SMD != NOSM) {
|
||||
extern __shared__ double2 sharedMemoryFFT[];
|
||||
fft = sharedMemoryFFT;
|
||||
}
|
||||
|
||||
extern __shared__ double2 sharedMemoryFFT[];
|
||||
double2 *fft = (SMD == NOSM) ? &buffer[blockIdx.x * params::degree / 2]
|
||||
: sharedMemoryFFT;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
|
||||
Reference in New Issue
Block a user