diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu index 4ca6ab5b9..a1d1dc5af 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit.cu @@ -398,20 +398,32 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64( uint32_t polynomial_size, uint32_t level_count, uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) { + bool supports_cg = + supports_cooperative_groups_on_multibit_programmable_bootstrap( + glwe_dimension, polynomial_size, level_count, + input_lwe_ciphertext_count, cuda_get_max_shared_memory(gpu_index)); #if (CUDA_ARCH >= 900) - if (has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( + // On H100s we should be using TBC until num_samples < num_sms / 2. + // After that we switch to CG until not supported anymore. + // At this point we return to TBC. + int num_sms = 0; + check_cuda_error(cudaDeviceGetAttribute( + &num_sms, cudaDevAttrMultiProcessorCount, gpu_index)); + + bool supports_tbc = + has_support_to_cuda_programmable_bootstrap_tbc_multi_bit( input_lwe_ciphertext_count, glwe_dimension, polynomial_size, - level_count, cuda_get_max_shared_memory(gpu_index))) + level_count, cuda_get_max_shared_memory(gpu_index)); + + if (supports_tbc && + !(input_lwe_ciphertext_count > num_sms / 2 && supports_cg)) return scratch_cuda_tbc_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)buffer, glwe_dimension, polynomial_size, level_count, input_lwe_ciphertext_count, allocate_gpu_memory); else #endif - if (supports_cooperative_groups_on_multibit_programmable_bootstrap< - uint64_t>(glwe_dimension, polynomial_size, level_count, - input_lwe_ciphertext_count, - cuda_get_max_shared_memory(gpu_index))) + if (supports_cg) return scratch_cuda_cg_multi_bit_programmable_bootstrap( stream, gpu_index, (pbs_buffer **)buffer, glwe_dimension, polynomial_size, level_count,