From 57ea3e3e8894e8fc753c70586482315df84ca0ae Mon Sep 17 00:00:00 2001
From: Pedro Alves
Date: Mon, 18 Aug 2025 10:57:03 -0300
Subject: [PATCH] chore(gpu): refactor the entry points for PBS in the backend

---
 .../cuda/include/helper_multi_gpu.h           |  39 ++--
 .../cuda/include/integer/integer_utilities.h  |  11 +-
 .../cuda/src/crypto/keyswitch.cuh             |   8 +-
 .../cuda/src/crypto/packing_keyswitch.cuh     |   4 +-
 .../src/integer/compression/compression.cuh   |   4 +-
 .../cuda/src/integer/integer.cuh              |  42 +++--
 .../cuda/src/integer/multiplication.cuh       |   4 +-
 .../cuda/src/integer/oprf.cuh                 |   4 +-
 .../cuda/src/pbs/programmable_bootstrap.cuh   | 172 +++++++++++++-----
 .../src/pbs/programmable_bootstrap_128.cuh    |  45 -----
 10 files changed, 195 insertions(+), 138 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h b/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h
index 667c5f21d..4cfd4c8cd 100644
--- a/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h
+++ b/backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h
@@ -16,24 +16,27 @@ int32_t cuda_setup_multi_gpu(int device_0_id);
 template <typename Torus>
 using LweArrayVariant = std::variant<std::vector<Torus *>, Torus *>;
 
-// Macro to define the visitor logic using std::holds_alternative for vectors
-#define GET_VARIANT_ELEMENT(variant, index)                                    \
-  [&] {                                                                        \
-    if (std::holds_alternative<std::vector<Torus *>>(variant)) {               \
-      return std::get<std::vector<Torus *>>(variant)[index];                   \
-    } else {                                                                   \
-      return std::get<Torus *>(variant);                                       \
-    }                                                                          \
-  }()
-// Macro to define the visitor logic using std::holds_alternative for vectors
-#define GET_VARIANT_ELEMENT_64BIT(variant, index)                              \
-  [&] {                                                                        \
-    if (std::holds_alternative<std::vector<uint64_t *>>(variant)) {            \
-      return std::get<std::vector<uint64_t *>>(variant)[index];                \
-    } else {                                                                   \
-      return std::get<uint64_t *>(variant);                                    \
-    }                                                                          \
-  }()
+/// get_variant_element() resolves access when the input may be either a single
+/// pointer or a vector of pointers. If the variant holds a single pointer, the
+/// index is ignored and that pointer is returned; if it holds a vector, the
+/// element at `index` is returned.
+///
+/// This function replaces the previous macro:
+/// - Easier to debug and read than a macro
+/// - Deduces the pointer type from the variant (no need to name a Torus type
+///   explicitly)
+/// - Defined in a header, so it's eligible for inlining by the optimizer
+template <typename Torus>
+inline Torus
+get_variant_element(const std::variant<std::vector<Torus>, Torus> &variant,
+                    size_t index) {
+  if (std::holds_alternative<std::vector<Torus>>(variant)) {
+    return std::get<std::vector<Torus>>(variant)[index];
+  } else {
+    return std::get<Torus>(variant);
+  }
+}
 
 int get_active_gpu_count(int num_inputs, int gpu_count);
 
 int get_num_inputs_on_gpu(int total_num_inputs, int gpu_index, int gpu_count);
diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
index 3015cc474..bb4766c59 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
@@ -6,7 +6,6 @@
 #include "integer/radix_ciphertext.h"
 #include "keyswitch/keyswitch.h"
 #include "pbs/programmable_bootstrap.cuh"
-#include "pbs/programmable_bootstrap_128.cuh"
 #include "utils/helper_multi_gpu.cuh"
 #include
 #include
@@ -876,11 +875,11 @@ template <typename Torus> struct int_noise_squashing_lut {
           get_num_inputs_on_gpu(num_radix_blocks, i, active_gpu_count));
       int8_t *gpu_pbs_buffer;
       uint64_t size = 0;
-      execute_scratch_pbs_128(streams[i], gpu_indexes[i], &gpu_pbs_buffer,
-                              params.small_lwe_dimension, params.glwe_dimension,
-                              params.polynomial_size, params.pbs_level,
-                              num_radix_blocks_on_gpu, allocate_gpu_memory,
-                              params.noise_reduction_type, size);
+      execute_scratch_pbs<__uint128_t>(
+          streams[i], gpu_indexes[i], &gpu_pbs_buffer, params.glwe_dimension,
+          params.small_lwe_dimension, params.polynomial_size, params.pbs_level,
+          params.grouping_factor, num_radix_blocks_on_gpu, params.pbs_type,
+          allocate_gpu_memory, params.noise_reduction_type, size);
       cuda_synchronize_stream(streams[i], gpu_indexes[i]);
       if (i == 0) {
         size_tracker += size;
diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
index 75ef1ad12..c2117f0bb 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -157,12 +157,12 @@ void execute_keyswitch_async(cudaStream_t const *streams,
   for (uint i = 0; i < gpu_count; i++) {
     int num_samples_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
 
-    Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
+    Torus *current_lwe_array_out = get_variant_element(lwe_array_out, i);
     Torus *current_lwe_output_indexes =
-        GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-    Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
+        get_variant_element(lwe_output_indexes, i);
+    Torus *current_lwe_array_in = get_variant_element(lwe_array_in, i);
     Torus *current_lwe_input_indexes =
-        GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+        get_variant_element(lwe_input_indexes, i);
 
     // Compute Keyswitch
     host_keyswitch_lwe_ciphertext_vector(
diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/packing_keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/packing_keyswitch.cuh
index 328a1a862..5192307dd 100644
--- a/backends/tfhe-cuda-backend/cuda/src/crypto/packing_keyswitch.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/crypto/packing_keyswitch.cuh
@@ -202,9 +202,9 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
 
   auto stride_KSK_buffer = glwe_accumulator_size * level_count;
 
-  // Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
-  // 32, 64, and 128-bit Torus elements We want to keep this as a sanity check
   uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
+  // Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
+  // 32, 64, and 128-bit Torus elements
   // Sanity check: the shared memory size is a constant defined by the algorithm
   GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
              "GEMM kernel error: shared memory required might be too large");
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
index d32d07afc..917337b69 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
@@ -344,7 +344,7 @@ host_integer_decompress(cudaStream_t const *streams,
   auto active_gpu_count =
       get_active_gpu_count(num_blocks_to_decompress, gpu_count);
   if (active_gpu_count == 1) {
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
        streams, gpu_indexes, active_gpu_count, (Torus *)d_lwe_array_out->ptr,
        lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
        extracted_lwe, lut->lwe_indexes_in, d_bsks, nullptr, lut->buffer,
@@ -374,7 +374,7 @@ host_integer_decompress(cudaStream_t const *streams,
         compression_params.small_lwe_dimension + 1);
 
     /// Apply PBS
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_array_in_vec, lwe_trivial_indexes_vec, d_bsks, nullptr,
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
index 9502d6ba7..d298fd993 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -558,7 +558,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)lwe_array_out->ptr,
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -586,7 +586,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -665,7 +665,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)lwe_array_out->ptr,
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -693,7 +693,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -787,7 +787,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)(lwe_array_out->ptr),
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -811,7 +811,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -1486,7 +1486,7 @@ void host_full_propagate_inplace(
         streams[0], gpu_indexes[0], mem_ptr->tmp_small_lwe_vector, 1, 2,
         mem_ptr->tmp_small_lwe_vector, 0, 1);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
         mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
         mem_ptr->lut->lut_indexes_vec,
@@ -2344,11 +2344,17 @@ __host__ void integer_radix_apply_noise_squashing_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_128_async<__uint128_t>(
+    ///
+    /// int_noise_squashing_lut doesn't support output or LUT indexing other
+    /// than the trivial one
+    execute_pbs_async<uint64_t, __uint128_t>(
         streams, gpu_indexes, 1, (__uint128_t *)lwe_array_out->ptr,
-        lut->lut_vec, lwe_after_ks_vec[0], bsks, ms_noise_reduction_key,
-        lut->pbs_buffer, small_lwe_dimension, glwe_dimension, polynomial_size,
-        pbs_base_log, pbs_level, lwe_array_out->num_radix_blocks);
+        lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec,
+        lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
+        ms_noise_reduction_key, lut->pbs_buffer, glwe_dimension,
+        small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
+        grouping_factor, lwe_array_out->num_radix_blocks, params.pbs_type, 0,
+        0);
   } else {
     /// Make sure all data that should be on GPU 0 is indeed there
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
@@ -2367,11 +2373,15 @@ __host__ void integer_radix_apply_noise_squashing_kb(
         ksks, lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
         ks_level, lwe_array_out->num_radix_blocks);
 
-    execute_pbs_128_async<__uint128_t>(
-        streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec, lut->lut_vec,
-        lwe_after_ks_vec, bsks, ms_noise_reduction_key, lut->pbs_buffer,
-        small_lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log,
-        pbs_level, lwe_array_out->num_radix_blocks);
+    /// int_noise_squashing_lut doesn't support output or LUT indexing other
+    /// than the trivial one
+    execute_pbs_async<uint64_t, __uint128_t>(
+        streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
+        lwe_trivial_indexes_vec, lut->lut_vec, lwe_trivial_indexes_vec,
+        lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
+        lut->pbs_buffer, glwe_dimension, small_lwe_dimension, polynomial_size,
+        pbs_base_log, pbs_level, grouping_factor,
+        lwe_array_out->num_radix_blocks, params.pbs_type, 0, 0);
 
     /// Copy data back to GPU 0 and release vecs
     /// In apply noise squashing we always use trivial indexes
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
index d86a2fbb9..ad50af6e9 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
@@ -404,7 +404,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
           mem_ptr->params.ks_base_log, mem_ptr->params.ks_level,
           total_messages);
 
-      execute_pbs_async<Torus>(
+      execute_pbs_async<Torus, Torus>(
          streams, gpu_indexes, 1, (Torus *)current_blocks->ptr,
          d_pbs_indexes_out, luts_message_carry->lut_vec,
          luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
@@ -479,7 +479,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
         big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
         mem_ptr->params.ks_level, num_radix_blocks);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
        streams, gpu_indexes, 1, (Torus *)current_blocks->ptr,
        d_pbs_indexes_out, luts_message_carry->lut_vec,
        luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh
index 9e758cfa0..eb7972017 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh
@@ -34,7 +34,7 @@ void host_integer_grouped_oprf(
   auto lut = mem_ptr->luts;
 
   if (active_gpu_count == 1) {
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, (uint32_t)1, (Torus *)(radix_lwe_out->ptr),
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
@@ -60,7 +60,7 @@ void host_integer_grouped_oprf(
         active_gpu_count, num_blocks_to_process,
         mem_ptr->params.small_lwe_dimension + 1);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_array_in_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
index 092d916a0..52d58f1ff 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
@@ -7,6 +7,7 @@
 #include "device.h"
 #include "fft/bnsmfft.cuh"
 #include "helper_multi_gpu.h"
+#include "pbs/pbs_128_utilities.h"
 #include "pbs/programmable_bootstrap_multibit.h"
 #include "polynomial/polynomial_math.cuh"
 
@@ -202,15 +203,15 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
   // the buffer in registers to avoid synchronizations and shared memory usage
 }
 
-template <typename Torus>
+template <typename InputTorus, typename OutputTorus>
 void execute_pbs_async(
     cudaStream_t const *streams, uint32_t const *gpu_indexes,
-    uint32_t gpu_count, const LweArrayVariant<Torus> &lwe_array_out,
-    const LweArrayVariant<Torus> &lwe_output_indexes,
-    const std::vector<Torus *> lut_vec,
-    const std::vector<Torus *> lut_indexes_vec,
-    const LweArrayVariant<Torus> &lwe_array_in,
-    const LweArrayVariant<Torus> &lwe_input_indexes,
+    uint32_t gpu_count, const LweArrayVariant<OutputTorus> &lwe_array_out,
+    const LweArrayVariant<InputTorus> &lwe_output_indexes,
+    const std::vector<OutputTorus *> lut_vec,
+    const std::vector<InputTorus *> lut_indexes_vec,
+    const LweArrayVariant<InputTorus> &lwe_array_in,
+    const LweArrayVariant<InputTorus> &lwe_input_indexes,
     void *const *bootstrapping_keys,
     CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
     std::vector<int8_t *> pbs_buffer, uint32_t glwe_dimension,
@@ -219,8 +220,7 @@ void execute_pbs_async(
     uint32_t input_lwe_ciphertext_count, PBS_TYPE pbs_type,
     uint32_t num_many_lut, uint32_t lut_stride) {
 
-  switch (sizeof(Torus)) {
-  case sizeof(uint32_t):
+  if constexpr (std::is_same_v<OutputTorus, uint32_t>) {
     // 32 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -238,12 +238,12 @@ void execute_pbs_async(
       // Use the macro to get the correct elements for the current iteration
       // Handles the case when the input/output are scattered through
       // different gpus and when it is not
-      Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-      Torus *current_lwe_output_indexes =
-          GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-      Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
-      Torus *current_lwe_input_indexes =
-          GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+      auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+      auto current_lwe_output_indexes =
+          get_variant_element(lwe_output_indexes, i);
+      auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+      auto current_lwe_input_indexes =
+          get_variant_element(lwe_input_indexes, i);
 
       cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
           streams[i], gpu_indexes[i], current_lwe_array_out,
@@ -257,8 +257,7 @@ void execute_pbs_async(
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  case sizeof(uint64_t):
+  } else if constexpr (std::is_same_v<OutputTorus, uint64_t>) {
     // 64 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -271,12 +270,12 @@ void execute_pbs_async(
       // Use the macro to get the correct elements for the current iteration
      // Handles the case when the input/output are scattered through
       // different gpus and when it is not
-      Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-      Torus *current_lwe_output_indexes =
-          GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-      Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
-      Torus *current_lwe_input_indexes =
-          GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+      auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+      auto current_lwe_output_indexes =
+          get_variant_element(lwe_output_indexes, i);
+      auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+      auto current_lwe_input_indexes =
+          get_variant_element(lwe_input_indexes, i);
 
       int gpu_offset =
           get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
@@ -300,12 +299,12 @@ void execute_pbs_async(
       // Use the macro to get the correct elements for the current iteration
       // Handles the case when the input/output are scattered through
       // different gpus and when it is not
-      Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-      Torus *current_lwe_output_indexes =
-          GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-      Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
-      Torus *current_lwe_input_indexes =
-          GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+      auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+      auto current_lwe_output_indexes =
+          get_variant_element(lwe_output_indexes, i);
+      auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+      auto current_lwe_input_indexes =
+          get_variant_element(lwe_input_indexes, i);
 
       int gpu_offset =
           get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
@@ -328,10 +327,81 @@ void execute_pbs_async(
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  default:
-    PANIC("Cuda error: unsupported modulus size: only 32 and 64 bit integer "
-          "moduli are supported.")
+  } else if constexpr (std::is_same_v<OutputTorus, __uint128_t>) {
+    // 128 bits
+    switch (pbs_type) {
+    case MULTI_BIT:
+      if (grouping_factor == 0)
+        PANIC("Multi-bit PBS error: grouping factor should be > 0.")
+      for (uint i = 0; i < gpu_count; i++) {
+        int num_inputs_on_gpu =
+            get_num_inputs_on_gpu(input_lwe_ciphertext_count, i, gpu_count);
+
+        // Use get_variant_element() to get the correct elements for the
+        // current iteration. Handles the case when the input/output are
+        // scattered through different GPUs and when they are not
+        auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+        auto current_lwe_output_indexes =
+            get_variant_element(lwe_output_indexes, i);
+        auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+        auto current_lwe_input_indexes =
+            get_variant_element(lwe_input_indexes, i);
+
+        int gpu_offset =
+            get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
+        auto d_lut_vector_indexes =
+            lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
+
+        cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
+            streams[i], gpu_indexes[i], current_lwe_array_out,
+            current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
+            current_lwe_array_in, current_lwe_input_indexes,
+            bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
+            polynomial_size, grouping_factor, base_log, level_count,
+            num_inputs_on_gpu, num_many_lut, lut_stride);
+      }
+      break;
+    case CLASSICAL:
+      for (uint i = 0; i < gpu_count; i++) {
+        int num_inputs_on_gpu =
+            get_num_inputs_on_gpu(input_lwe_ciphertext_count, i, gpu_count);
+
+        // Use get_variant_element() to get the correct elements for the
+        // current iteration. Handles the case when the input/output are
+        // scattered through different GPUs and when they are not
+        auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+        auto current_lwe_output_indexes =
+            get_variant_element(lwe_output_indexes, i);
+        auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+        auto current_lwe_input_indexes =
+            get_variant_element(lwe_input_indexes, i);
+
+        int gpu_offset =
+            get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
+        auto d_lut_vector_indexes =
+            lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
+
+        void *zeros = nullptr;
+        if (ms_noise_reduction_key != nullptr &&
+            ms_noise_reduction_key->ptr != nullptr)
+          zeros = ms_noise_reduction_key->ptr[i];
+        cuda_programmable_bootstrap_lwe_ciphertext_vector_128(
+            streams[i], gpu_indexes[i], current_lwe_array_out, lut_vec[i],
+            current_lwe_array_in, bootstrapping_keys[i], ms_noise_reduction_key,
+            zeros, pbs_buffer[i], lwe_dimension, glwe_dimension,
+            polynomial_size, base_log, level_count, num_inputs_on_gpu);
+      }
+      break;
+    default:
+      PANIC("Error: unsupported cuda PBS type.")
+    }
+  } else {
+    static_assert(
+        std::is_same_v<OutputTorus, uint32_t> ||
+            std::is_same_v<OutputTorus, uint64_t> ||
+            std::is_same_v<OutputTorus, __uint128_t>,
+        "Cuda error: unsupported modulus size: only 32, 64, or 128-bit integer "
+        "moduli are supported.");
   }
 }
 
@@ -344,8 +414,7 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
                          bool allocate_gpu_memory,
                          PBS_MS_REDUCTION_T noise_reduction_type,
                          uint64_t &size_tracker) {
-  switch (sizeof(Torus)) {
-  case sizeof(uint32_t):
+  if constexpr (std::is_same_v<Torus, uint32_t>) {
     // 32 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -359,8 +428,7 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
     default:
      PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  case sizeof(uint64_t):
+  } else if constexpr (std::is_same_v<Torus, uint64_t>) {
     // 64 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -379,10 +447,32 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  default:
-    PANIC("Cuda error: unsupported modulus size: only 32 and 64 bit integer "
-          "moduli are supported.")
+  } else if constexpr (std::is_same_v<Torus, __uint128_t>) {
+    // 128 bits
+    switch (pbs_type) {
+    case MULTI_BIT:
+      if (grouping_factor == 0)
+        PANIC("Multi-bit PBS error: grouping factor should be > 0.")
+      size_tracker =
+          scratch_cuda_multi_bit_programmable_bootstrap_128_vector_64(
+              stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size,
+              level_count, input_lwe_ciphertext_count, allocate_gpu_memory);
+      break;
+    case CLASSICAL:
+      size_tracker = scratch_cuda_programmable_bootstrap_128(
+          stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
+          polynomial_size, level_count, input_lwe_ciphertext_count,
+          allocate_gpu_memory, noise_reduction_type);
+      break;
+    default:
+      PANIC("Error: unsupported cuda PBS type.")
+    }
+  } else {
+    static_assert(
+        std::is_same_v<Torus, uint32_t> || std::is_same_v<Torus, uint64_t> ||
+            std::is_same_v<Torus, __uint128_t>,
+        "Cuda error: unsupported modulus size: only 32, 64, or 128-bit integer "
+        "moduli are supported.");
   }
 }
 
diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_128.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_128.cuh
index 3540dbe84..e69de29bb 100644
--- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_128.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_128.cuh
@@ -1,45 +0,0 @@
-#ifndef CUDA_PROGRAMMABLE_BOOTSTRAP_128_CUH
-#define CUDA_PROGRAMMABLE_BOOTSTRAP_128_CUH
-#include "pbs/pbs_128_utilities.h"
-
-static void execute_scratch_pbs_128(
-    void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
-    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t level_count, uint32_t input_lwe_ciphertext_count,
-    bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type,
-    uint64_t &size_tracker_on_gpu) {
-  // The squash noise function receives as input 64-bit integers
-  size_tracker_on_gpu = scratch_cuda_programmable_bootstrap_128_vector_64(
-      stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
-      polynomial_size, level_count, input_lwe_ciphertext_count,
-      allocate_gpu_memory, noise_reduction_type);
-}
-template <typename Torus>
-static void execute_pbs_128_async(
-    cudaStream_t const *streams, uint32_t const *gpu_indexes,
-    uint32_t gpu_count, const LweArrayVariant<__uint128_t> &lwe_array_out,
-    const std::vector<__uint128_t *> lut_vector,
-    const LweArrayVariant<uint64_t> &lwe_array_in,
-    void *const *bootstrapping_keys,
-    CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
-    std::vector<int8_t *> pbs_buffer, uint32_t lwe_dimension,
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
-    uint32_t level_count, uint32_t num_samples) {
-
-  for (uint32_t i = 0; i < gpu_count; i++) {
-    int num_inputs_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
-
-    Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-    uint64_t *current_lwe_array_in = GET_VARIANT_ELEMENT_64BIT(lwe_array_in, i);
-    void *zeros = nullptr;
-    if (ms_noise_reduction_key != nullptr)
-      zeros = ms_noise_reduction_key->ptr[i];
-
-    cuda_programmable_bootstrap_lwe_ciphertext_vector_128(
-        streams[i], gpu_indexes[i], current_lwe_array_out, lut_vector[i],
-        current_lwe_array_in, bootstrapping_keys[i], ms_noise_reduction_key,
-        zeros, pbs_buffer[i], lwe_dimension, glwe_dimension, polynomial_size,
-        base_log, level_count, num_inputs_on_gpu);
-  }
-}
-#endif
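
Review note, outside the patch: for readers unfamiliar with the variant helper, below is a minimal, self-contained sketch of the get_variant_element() semantics the patch introduces, assuming only the C++17 standard library. The uint64_t * instantiation and the main() driver are illustrative and not part of the patch.

// Sketch of get_variant_element(): single pointer vs. vector of pointers.
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

template <typename Torus>
inline Torus
get_variant_element(const std::variant<std::vector<Torus>, Torus> &variant,
                    size_t index) {
  // Vector alternative: return the element at `index`.
  if (std::holds_alternative<std::vector<Torus>>(variant)) {
    return std::get<std::vector<Torus>>(variant)[index];
  } else {
    // Single-pointer alternative: `index` is ignored.
    return std::get<Torus>(variant);
  }
}

int main() {
  uint64_t a = 0, b = 0;
  // Single pointer: the index argument is ignored.
  std::variant<std::vector<uint64_t *>, uint64_t *> single = &a;
  assert(get_variant_element(single, 3) == &a);
  // Vector of pointers (e.g. one buffer per GPU): element `index` is returned.
  std::variant<std::vector<uint64_t *>, uint64_t *> many =
      std::vector<uint64_t *>{&a, &b};
  assert(get_variant_element(many, 1) == &b);
  return 0;
}

Unlike the old GET_VARIANT_ELEMENT macro, the pointer type is deduced from the variant itself, which is what lets the call sites in the patch switch from `Torus *x = GET_VARIANT_ELEMENT(...)` to `auto x = get_variant_element(...)`.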
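The dispatch change follows the same pattern in both entry points: the runtime `switch (sizeof(Torus))` becomes a compile-time `if constexpr` chain, so the compiler only instantiates the branch that matches the Torus type (the 128-bit-only calls are never even compiled for a 32-bit instantiation), and a dependent static_assert in the final else turns an unsupported Torus into a compile error instead of a runtime PANIC. A reduced sketch of the pattern, with illustrative names that are not from the patch:

// Compile-time dispatch on the Torus type, mirroring the patch's structure.
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename Torus> void dispatch_by_torus_width() {
  if constexpr (std::is_same_v<Torus, uint32_t>) {
    std::puts("32-bit path");
  } else if constexpr (std::is_same_v<Torus, uint64_t>) {
    std::puts("64-bit path");
  } else if constexpr (std::is_same_v<Torus, __uint128_t>) {
    std::puts("128-bit path");
  } else {
    // Dependent condition: only fires if an unsupported Torus is instantiated.
    static_assert(std::is_same_v<Torus, uint32_t> ||
                      std::is_same_v<Torus, uint64_t> ||
                      std::is_same_v<Torus, __uint128_t>,
                  "only 32, 64, or 128-bit integer moduli are supported");
  }
}

int main() {
  dispatch_by_torus_width<uint64_t>();    // prints "64-bit path"
  dispatch_by_torus_width<__uint128_t>(); // prints "128-bit path"
  return 0;
}

This is also more precise than dispatching on sizeof: sizeof cannot distinguish two different types of the same width, while is_same_v matches the exact type.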