mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
chore(gpu): switches from the TBC PBS to the other variants for many inputs
This commit is contained in:
@@ -398,20 +398,32 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64(
|
||||
uint32_t polynomial_size, uint32_t level_count,
|
||||
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {
|
||||
|
||||
bool supports_cg =
|
||||
supports_cooperative_groups_on_multibit_programmable_bootstrap<uint64_t>(
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
input_lwe_ciphertext_count, cuda_get_max_shared_memory(gpu_index));
|
||||
#if (CUDA_ARCH >= 900)
|
||||
if (has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
|
||||
// On H100s we should be using TBC until num_samples < num_sms / 2.
|
||||
// After that we switch to CG until not supported anymore.
|
||||
// At this point we return to TBC.
|
||||
int num_sms = 0;
|
||||
check_cuda_error(cudaDeviceGetAttribute(
|
||||
&num_sms, cudaDevAttrMultiProcessorCount, gpu_index));
|
||||
|
||||
bool supports_tbc =
|
||||
has_support_to_cuda_programmable_bootstrap_tbc_multi_bit<uint64_t>(
|
||||
input_lwe_ciphertext_count, glwe_dimension, polynomial_size,
|
||||
level_count, cuda_get_max_shared_memory(gpu_index)))
|
||||
level_count, cuda_get_max_shared_memory(gpu_index));
|
||||
|
||||
if (supports_tbc &&
|
||||
!(input_lwe_ciphertext_count > num_sms / 2 && supports_cg))
|
||||
return scratch_cuda_tbc_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
input_lwe_ciphertext_count, allocate_gpu_memory);
|
||||
else
|
||||
#endif
|
||||
if (supports_cooperative_groups_on_multibit_programmable_bootstrap<
|
||||
uint64_t>(glwe_dimension, polynomial_size, level_count,
|
||||
input_lwe_ciphertext_count,
|
||||
cuda_get_max_shared_memory(gpu_index)))
|
||||
if (supports_cg)
|
||||
return scratch_cuda_cg_multi_bit_programmable_bootstrap<uint64_t>(
|
||||
stream, gpu_index, (pbs_buffer<uint64_t, MULTI_BIT> **)buffer,
|
||||
glwe_dimension, polynomial_size, level_count,
|
||||
|
||||
Reference in New Issue
Block a user