chore(gpu): refactor the entry points for PBS in the backend

Pedro Alves
2025-08-18 10:57:03 -03:00
committed by Pedro Alves
parent cad4070ebe
commit 57ea3e3e88
10 changed files with 195 additions and 138 deletions

View File

@@ -16,24 +16,27 @@ int32_t cuda_setup_multi_gpu(int device_0_id);
template <typename Torus>
using LweArrayVariant = std::variant<std::vector<Torus *>, Torus *>;
// Macro to define the visitor logic using std::holds_alternative for vectors
#define GET_VARIANT_ELEMENT(variant, index) \
[&] { \
if (std::holds_alternative<std::vector<Torus *>>(variant)) { \
return std::get<std::vector<Torus *>>(variant)[index]; \
} else { \
return std::get<Torus *>(variant); \
} \
}()
// Macro to define the visitor logic using std::holds_alternative for vectors
#define GET_VARIANT_ELEMENT_64BIT(variant, index) \
[&] { \
if (std::holds_alternative<std::vector<uint64_t *>>(variant)) { \
return std::get<std::vector<uint64_t *>>(variant)[index]; \
} else { \
return std::get<uint64_t *>(variant); \
} \
}()
/// get_variant_element() resolves access when the input may be either a single
/// pointer or a vector of pointers. If the variant holds a single pointer, the
/// index is ignored and that pointer is returned; if it holds a vector, the
/// element at `index` is returned.
///
/// This function replaces the previous macro:
/// - Easier to debug and read than a macro
/// - Deduces the pointer type from the variant (no need to name a Torus type
/// explicitly)
/// - Defined in a header, so it's eligible for inlining by the optimizer
template <typename Torus>
inline Torus
get_variant_element(const std::variant<std::vector<Torus>, Torus> &variant,
size_t index) {
if (std::holds_alternative<std::vector<Torus>>(variant)) {
return std::get<std::vector<Torus>>(variant)[index];
} else {
return std::get<Torus>(variant);
}
}
int get_active_gpu_count(int num_inputs, int gpu_count);
int get_num_inputs_on_gpu(int total_num_inputs, int gpu_index, int gpu_count);
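A short usage sketch of get_variant_element (hypothetical pointers; assumes this header is included):

#include <cstdint>
#include <vector>

void variant_usage_example(uint64_t *single_ptr, uint64_t *ptr_gpu0,
                           uint64_t *ptr_gpu1) {
  // Single-pointer alternative: all GPUs share one array, the index is
  // ignored.
  LweArrayVariant<uint64_t> shared = single_ptr;
  uint64_t *p0 = get_variant_element(shared, 3); // == single_ptr

  // Vector alternative: one pointer per GPU, selected by index.
  LweArrayVariant<uint64_t> per_gpu =
      std::vector<uint64_t *>{ptr_gpu0, ptr_gpu1};
  uint64_t *p1 = get_variant_element(per_gpu, 1); // == ptr_gpu1
  (void)p0;
  (void)p1;
}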

View File

@@ -6,7 +6,6 @@
#include "integer/radix_ciphertext.h"
#include "keyswitch/keyswitch.h"
#include "pbs/programmable_bootstrap.cuh"
#include "pbs/programmable_bootstrap_128.cuh"
#include "utils/helper_multi_gpu.cuh"
#include <cmath>
#include <functional>
@@ -876,11 +875,11 @@ template <typename InputTorus> struct int_noise_squashing_lut {
get_num_inputs_on_gpu(num_radix_blocks, i, active_gpu_count));
int8_t *gpu_pbs_buffer;
uint64_t size = 0;
execute_scratch_pbs_128(streams[i], gpu_indexes[i], &gpu_pbs_buffer,
params.small_lwe_dimension, params.glwe_dimension,
params.polynomial_size, params.pbs_level,
num_radix_blocks_on_gpu, allocate_gpu_memory,
params.noise_reduction_type, size);
execute_scratch_pbs<__uint128_t>(
streams[i], gpu_indexes[i], &gpu_pbs_buffer, params.glwe_dimension,
params.small_lwe_dimension, params.polynomial_size, params.pbs_level,
params.grouping_factor, num_radix_blocks_on_gpu, params.pbs_type,
allocate_gpu_memory, params.noise_reduction_type, size);
cuda_synchronize_stream(streams[i], gpu_indexes[i]);
if (i == 0) {
size_tracker += size;

View File

@@ -157,12 +157,12 @@ void execute_keyswitch_async(cudaStream_t const *streams,
for (uint i = 0; i < gpu_count; i++) {
int num_samples_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
Torus *current_lwe_array_out = get_variant_element(lwe_array_out, i);
Torus *current_lwe_output_indexes =
GET_VARIANT_ELEMENT(lwe_output_indexes, i);
Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
get_variant_element(lwe_output_indexes, i);
Torus *current_lwe_array_in = get_variant_element(lwe_array_in, i);
Torus *current_lwe_input_indexes =
GET_VARIANT_ELEMENT(lwe_input_indexes, i);
get_variant_element(lwe_input_indexes, i);
// Compute Keyswitch
host_keyswitch_lwe_ciphertext_vector<Torus>(

View File

@@ -202,9 +202,9 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
auto stride_KSK_buffer = glwe_accumulator_size * level_count;
// Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
// 32, 64, and 128-bit Torus elements We want to keep this as a sanity check
uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
// Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
// 32, 64, and 128-bit Torus elements
// Sanity check: the shared memory size is a constant defined by the algorithm
GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
"GEMM kernel error: shared memory required might be too large");

View File

@@ -344,7 +344,7 @@ host_integer_decompress(cudaStream_t const *streams,
auto active_gpu_count =
get_active_gpu_count(num_blocks_to_decompress, gpu_count);
if (active_gpu_count == 1) {
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, active_gpu_count, (Torus *)d_lwe_array_out->ptr,
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
extracted_lwe, lut->lwe_indexes_in, d_bsks, nullptr, lut->buffer,
@@ -374,7 +374,7 @@ host_integer_decompress(cudaStream_t const *streams,
compression_params.small_lwe_dimension + 1);
/// Apply PBS
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_array_in_vec, lwe_trivial_indexes_vec, d_bsks, nullptr,

View File

@@ -558,7 +558,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, 1, (Torus *)lwe_array_out->ptr,
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -586,7 +586,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -665,7 +665,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, 1, (Torus *)lwe_array_out->ptr,
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -693,7 +693,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -787,7 +787,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, 1, (Torus *)(lwe_array_out->ptr),
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -811,7 +811,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -1486,7 +1486,7 @@ void host_full_propagate_inplace(
streams[0], gpu_indexes[0], mem_ptr->tmp_small_lwe_vector, 1, 2,
mem_ptr->tmp_small_lwe_vector, 0, 1);
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, 1, (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
mem_ptr->lut->lut_indexes_vec,
@@ -2344,11 +2344,17 @@ __host__ void integer_radix_apply_noise_squashing_kb(
/// Apply the PBS to evaluate a LUT, reduce the noise, and go from a small
/// LWE dimension to a big LWE dimension
execute_pbs_128_async<__uint128_t>(
///
/// int_noise_squashing_lut doesn't support output or LUT indexing other
/// than the trivial one
execute_pbs_async<uint64_t, __uint128_t>(
streams, gpu_indexes, 1, (__uint128_t *)lwe_array_out->ptr,
lut->lut_vec, lwe_after_ks_vec[0], bsks, ms_noise_reduction_key,
lut->pbs_buffer, small_lwe_dimension, glwe_dimension, polynomial_size,
pbs_base_log, pbs_level, lwe_array_out->num_radix_blocks);
lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec,
lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
ms_noise_reduction_key, lut->pbs_buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
grouping_factor, lwe_array_out->num_radix_blocks, params.pbs_type, 0,
0);
} else {
/// Make sure all data that should be on GPU 0 is indeed there
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
@@ -2367,11 +2373,15 @@ __host__ void integer_radix_apply_noise_squashing_kb(
ksks, lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
ks_level, lwe_array_out->num_radix_blocks);
execute_pbs_128_async<__uint128_t>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec, lut->lut_vec,
lwe_after_ks_vec, bsks, ms_noise_reduction_key, lut->pbs_buffer,
small_lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log,
pbs_level, lwe_array_out->num_radix_blocks);
/// int_noise_squashing_lut doesn't support output or LUT indexing other
/// than the trivial one
execute_pbs_async<uint64_t, __uint128_t>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
lwe_trivial_indexes_vec, lut->lut_vec, lwe_trivial_indexes_vec,
lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
lut->pbs_buffer, glwe_dimension, small_lwe_dimension, polynomial_size,
pbs_base_log, pbs_level, grouping_factor,
lwe_array_out->num_radix_blocks, params.pbs_type, 0, 0);
/// Copy data back to GPU 0 and release the vectors
/// When applying noise squashing we always use trivial indexes
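For reference, the unified entry point types its arguments on two parameters: InputTorus for the input ciphertexts and every index array, OutputTorus for the output ciphertexts and the LUTs (see the execute_pbs_async signature further down). A reduced sketch of that split, with the noise-squashing instantiation noted; this is illustrative, not the real signature:

template <typename InputTorus, typename OutputTorus>
void pbs_typing_sketch(OutputTorus *lwe_array_out,
                       InputTorus *lwe_output_indexes,
                       OutputTorus *lut,
                       InputTorus *lut_indexes,
                       InputTorus *lwe_array_in,
                       InputTorus *lwe_input_indexes) {
  // Noise squashing instantiates <uint64_t, __uint128_t>: 64-bit inputs
  // and indexes are bootstrapped into 128-bit outputs and LUTs. The
  // trailing `0, 0` in the calls above map to num_many_lut and lut_stride.
}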

View File

@@ -404,7 +404,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
mem_ptr->params.ks_base_log, mem_ptr->params.ks_level,
total_messages);
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, 1, (Torus *)current_blocks->ptr,
d_pbs_indexes_out, luts_message_carry->lut_vec,
luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
@@ -479,7 +479,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
mem_ptr->params.ks_level, num_radix_blocks);
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, 1, (Torus *)current_blocks->ptr,
d_pbs_indexes_out, luts_message_carry->lut_vec,
luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,

View File

@@ -34,7 +34,7 @@ void host_integer_grouped_oprf(
auto lut = mem_ptr->luts;
if (active_gpu_count == 1) {
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, (uint32_t)1, (Torus *)(radix_lwe_out->ptr),
lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
@@ -60,7 +60,7 @@ void host_integer_grouped_oprf(
active_gpu_count, num_blocks_to_process,
mem_ptr->params.small_lwe_dimension + 1);
execute_pbs_async<Torus>(
execute_pbs_async<Torus, Torus>(
streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_array_in_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,

View File

@@ -7,6 +7,7 @@
#include "device.h"
#include "fft/bnsmfft.cuh"
#include "helper_multi_gpu.h"
#include "pbs/pbs_128_utilities.h"
#include "pbs/programmable_bootstrap_multibit.h"
#include "polynomial/polynomial_math.cuh"
@@ -202,15 +203,15 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
// the buffer in registers to avoid synchronizations and shared memory usage
}
template <typename Torus>
template <typename InputTorus, typename OutputTorus>
void execute_pbs_async(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, const LweArrayVariant<Torus> &lwe_array_out,
const LweArrayVariant<Torus> &lwe_output_indexes,
const std::vector<Torus *> lut_vec,
const std::vector<Torus *> lut_indexes_vec,
const LweArrayVariant<Torus> &lwe_array_in,
const LweArrayVariant<Torus> &lwe_input_indexes,
uint32_t gpu_count, const LweArrayVariant<OutputTorus> &lwe_array_out,
const LweArrayVariant<InputTorus> &lwe_output_indexes,
const std::vector<OutputTorus *> lut_vec,
const std::vector<InputTorus *> lut_indexes_vec,
const LweArrayVariant<InputTorus> &lwe_array_in,
const LweArrayVariant<InputTorus> &lwe_input_indexes,
void *const *bootstrapping_keys,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
std::vector<int8_t *> pbs_buffer, uint32_t glwe_dimension,
@@ -219,8 +220,7 @@ void execute_pbs_async(
uint32_t input_lwe_ciphertext_count, PBS_TYPE pbs_type,
uint32_t num_many_lut, uint32_t lut_stride) {
switch (sizeof(Torus)) {
case sizeof(uint32_t):
if constexpr (std::is_same_v<OutputTorus, uint32_t>) {
// 32 bits
switch (pbs_type) {
case MULTI_BIT:
@@ -238,12 +238,12 @@ void execute_pbs_async(
// Use get_variant_element() to pick the correct elements for the current
// iteration. Handles both the case where inputs/outputs are scattered
// across different GPUs and the case where they are not
Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
Torus *current_lwe_output_indexes =
GET_VARIANT_ELEMENT(lwe_output_indexes, i);
Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
Torus *current_lwe_input_indexes =
GET_VARIANT_ELEMENT(lwe_input_indexes, i);
auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
auto current_lwe_output_indexes =
get_variant_element(lwe_output_indexes, i);
auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
auto current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);
cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
streams[i], gpu_indexes[i], current_lwe_array_out,
@@ -257,8 +257,7 @@ void execute_pbs_async(
default:
PANIC("Error: unsupported cuda PBS type.")
}
break;
case sizeof(uint64_t):
} else if constexpr (std::is_same_v<OutputTorus, uint64_t>) {
// 64 bits
switch (pbs_type) {
case MULTI_BIT:
@@ -271,12 +270,12 @@ void execute_pbs_async(
// Use get_variant_element() to pick the correct elements for the current
// iteration. Handles both the case where inputs/outputs are scattered
// across different GPUs and the case where they are not
Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
Torus *current_lwe_output_indexes =
GET_VARIANT_ELEMENT(lwe_output_indexes, i);
Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
Torus *current_lwe_input_indexes =
GET_VARIANT_ELEMENT(lwe_input_indexes, i);
auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
auto current_lwe_output_indexes =
get_variant_element(lwe_output_indexes, i);
auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
auto current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);
int gpu_offset =
get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
@@ -300,12 +299,12 @@ void execute_pbs_async(
// Use get_variant_element() to pick the correct elements for the current
// iteration. Handles both the case where inputs/outputs are scattered
// across different GPUs and the case where they are not
Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
Torus *current_lwe_output_indexes =
GET_VARIANT_ELEMENT(lwe_output_indexes, i);
Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
Torus *current_lwe_input_indexes =
GET_VARIANT_ELEMENT(lwe_input_indexes, i);
auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
auto current_lwe_output_indexes =
get_variant_element(lwe_output_indexes, i);
auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
auto current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);
int gpu_offset =
get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
@@ -328,10 +327,81 @@ void execute_pbs_async(
default:
PANIC("Error: unsupported cuda PBS type.")
}
break;
default:
PANIC("Cuda error: unsupported modulus size: only 32 and 64 bit integer "
"moduli are supported.")
} else if constexpr (std::is_same_v<OutputTorus, __uint128_t>) {
// 128 bits
switch (pbs_type) {
case MULTI_BIT:
if (grouping_factor == 0)
PANIC("Multi-bit PBS error: grouping factor should be > 0.")
for (uint i = 0; i < gpu_count; i++) {
int num_inputs_on_gpu =
get_num_inputs_on_gpu(input_lwe_ciphertext_count, i, gpu_count);
// Use get_variant_element() to pick the correct elements for the current
// iteration. Handles both the case where inputs/outputs are scattered
// across different GPUs and the case where they are not
auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
auto current_lwe_output_indexes =
get_variant_element(lwe_output_indexes, i);
auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
auto current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);
int gpu_offset =
get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
auto d_lut_vector_indexes =
lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
streams[i], gpu_indexes[i], current_lwe_array_out,
current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
current_lwe_array_in, current_lwe_input_indexes,
bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
polynomial_size, grouping_factor, base_log, level_count,
num_inputs_on_gpu, num_many_lut, lut_stride);
}
break;
case CLASSICAL:
for (uint i = 0; i < gpu_count; i++) {
int num_inputs_on_gpu =
get_num_inputs_on_gpu(input_lwe_ciphertext_count, i, gpu_count);
// Use get_variant_element() to pick the correct elements for the current
// iteration. Handles both the case where inputs/outputs are scattered
// across different GPUs and the case where they are not
auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
auto current_lwe_output_indexes =
get_variant_element(lwe_output_indexes, i);
auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
auto current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);
int gpu_offset =
get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
auto d_lut_vector_indexes =
lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
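// Pick this GPU's zeros buffer from the modulus-switch noise-reduction
// key, guarding against both a null key and a null pointer array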
void *zeros = nullptr;
if (ms_noise_reduction_key != nullptr &&
ms_noise_reduction_key->ptr != nullptr)
zeros = ms_noise_reduction_key->ptr[i];
cuda_programmable_bootstrap_lwe_ciphertext_vector_128(
streams[i], gpu_indexes[i], current_lwe_array_out, lut_vec[i],
current_lwe_array_in, bootstrapping_keys[i], ms_noise_reduction_key,
zeros, pbs_buffer[i], lwe_dimension, glwe_dimension,
polynomial_size, base_log, level_count, num_inputs_on_gpu);
}
break;
default:
PANIC("Error: unsupported cuda PBS type.")
}
} else {
static_assert(
std::is_same_v<OutputTorus, uint32_t> ||
std::is_same_v<OutputTorus, uint64_t> ||
std::is_same_v<OutputTorus, __uint128_t>,
"Cuda error: unsupported modulus size: only 32, 64, or 128-bit integer "
"moduli are supported.");
}
}
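The dispatch pattern above in isolation: if constexpr selects the width at compile time, and the static_assert in the final else turns an unsupported Torus into a compile error instead of the former runtime PANIC. A minimal standalone sketch:

#include <cstdint>
#include <type_traits>

template <typename Torus> void dispatch_on_torus_width() {
  if constexpr (std::is_same_v<Torus, uint32_t>) {
    // 32-bit path
  } else if constexpr (std::is_same_v<Torus, uint64_t>) {
    // 64-bit path
  } else if constexpr (std::is_same_v<Torus, __uint128_t>) {
    // 128-bit path
  } else {
    // The condition depends on Torus, so the assertion only fires when
    // this branch is actually instantiated with an unsupported type.
    static_assert(std::is_same_v<Torus, uint32_t> ||
                      std::is_same_v<Torus, uint64_t> ||
                      std::is_same_v<Torus, __uint128_t>,
                  "unsupported Torus width");
  }
}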
@@ -344,8 +414,7 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type,
uint64_t &size_tracker) {
switch (sizeof(Torus)) {
case sizeof(uint32_t):
if constexpr (std::is_same_v<Torus, uint32_t>) {
// 32 bits
switch (pbs_type) {
case MULTI_BIT:
@@ -359,8 +428,7 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
default:
PANIC("Error: unsupported cuda PBS type.")
}
break;
case sizeof(uint64_t):
} else if constexpr (std::is_same_v<Torus, uint64_t>) {
// 64 bits
switch (pbs_type) {
case MULTI_BIT:
@@ -379,10 +447,32 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
default:
PANIC("Error: unsupported cuda PBS type.")
}
break;
default:
PANIC("Cuda error: unsupported modulus size: only 32 and 64 bit integer "
"moduli are supported.")
} else if constexpr (std::is_same_v<Torus, __uint128_t>) {
// 128 bits
switch (pbs_type) {
case MULTI_BIT:
if (grouping_factor == 0)
PANIC("Multi-bit PBS error: grouping factor should be > 0.")
size_tracker =
scratch_cuda_multi_bit_programmable_bootstrap_128_vector_64(
stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size,
level_count, input_lwe_ciphertext_count, allocate_gpu_memory);
break;
case CLASSICAL:
size_tracker = scratch_cuda_programmable_bootstrap_128(
stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
polynomial_size, level_count, input_lwe_ciphertext_count,
allocate_gpu_memory, noise_reduction_type);
break;
default:
PANIC("Error: unsupported cuda PBS type.")
}
} else {
static_assert(
std::is_same_v<Torus, uint32_t> || std::is_same_v<Torus, uint64_t> ||
std::is_same_v<Torus, __uint128_t>,
"Cuda error: unsupported modulus size: only 32, 64, or 128-bit integer "
"moduli are supported.");
}
}
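Callers pair the two entry points in a scratch-then-run pattern; a condensed single-GPU sketch with assumed in-scope variables, mirroring the int_noise_squashing_lut call site above rather than quoting a verbatim API:

int8_t *pbs_buffer = nullptr;
uint64_t size = 0;
execute_scratch_pbs<__uint128_t>(
    stream, gpu_index, &pbs_buffer, glwe_dimension, small_lwe_dimension,
    polynomial_size, pbs_level, grouping_factor, num_inputs, pbs_type,
    /*allocate_gpu_memory=*/true, noise_reduction_type, size);
// ... later, bootstrap with matching template arguments:
// execute_pbs_async<uint64_t, __uint128_t>(..., pbs_buffer, ...);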

View File

@@ -1,45 +0,0 @@
#ifndef CUDA_PROGRAMMABLE_BOOTSTRAP_128_CUH
#define CUDA_PROGRAMMABLE_BOOTSTRAP_128_CUH
#include "pbs/pbs_128_utilities.h"
static void execute_scratch_pbs_128(
void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t input_lwe_ciphertext_count,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type,
uint64_t &size_tracker_on_gpu) {
// The squash noise function receives as input 64-bit integers
size_tracker_on_gpu = scratch_cuda_programmable_bootstrap_128_vector_64(
stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
polynomial_size, level_count, input_lwe_ciphertext_count,
allocate_gpu_memory, noise_reduction_type);
}
template <typename Torus>
static void execute_pbs_128_async(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, const LweArrayVariant<__uint128_t> &lwe_array_out,
const std::vector<Torus *> lut_vector,
const LweArrayVariant<uint64_t> &lwe_array_in,
void *const *bootstrapping_keys,
CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
std::vector<int8_t *> pbs_buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples) {
for (uint32_t i = 0; i < gpu_count; i++) {
int num_inputs_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
uint64_t *current_lwe_array_in = GET_VARIANT_ELEMENT_64BIT(lwe_array_in, i);
void *zeros = nullptr;
if (ms_noise_reduction_key != nullptr)
zeros = ms_noise_reduction_key->ptr[i];
cuda_programmable_bootstrap_lwe_ciphertext_vector_128(
streams[i], gpu_indexes[i], current_lwe_array_out, lut_vector[i],
current_lwe_array_in, bootstrapping_keys[i], ms_noise_reduction_key,
zeros, pbs_buffer[i], lwe_dimension, glwe_dimension, polynomial_size,
base_log, level_count, num_inputs_on_gpu);
}
}
#endif