chore(gpu): refactor the entry points for PBS in the backend
@@ -16,24 +16,27 @@ int32_t cuda_setup_multi_gpu(int device_0_id);
 template <typename Torus>
 using LweArrayVariant = std::variant<std::vector<Torus *>, Torus *>;
 
-// Macro to define the visitor logic using std::holds_alternative for vectors
-#define GET_VARIANT_ELEMENT(variant, index)                                    \
-  [&] {                                                                        \
-    if (std::holds_alternative<std::vector<Torus *>>(variant)) {               \
-      return std::get<std::vector<Torus *>>(variant)[index];                   \
-    } else {                                                                   \
-      return std::get<Torus *>(variant);                                       \
-    }                                                                          \
-  }()
-// Macro to define the visitor logic using std::holds_alternative for vectors
-#define GET_VARIANT_ELEMENT_64BIT(variant, index)                              \
-  [&] {                                                                        \
-    if (std::holds_alternative<std::vector<uint64_t *>>(variant)) {            \
-      return std::get<std::vector<uint64_t *>>(variant)[index];                \
-    } else {                                                                   \
-      return std::get<uint64_t *>(variant);                                    \
-    }                                                                          \
-  }()
+/// get_variant_element() resolves access when the input may be either a single
+/// pointer or a vector of pointers. If the variant holds a single pointer, the
+/// index is ignored and that pointer is returned; if it holds a vector, the
+/// element at `index` is returned.
+///
+/// This function replaces the previous macro:
+/// - Easier to debug and read than a macro
+/// - Deduces the pointer type from the variant (no need to name a Torus type
+///   explicitly)
+/// - Defined in a header, so it's eligible for inlining by the optimizer
+template <typename Torus>
+inline Torus
+get_variant_element(const std::variant<std::vector<Torus>, Torus> &variant,
+                    size_t index) {
+  if (std::holds_alternative<std::vector<Torus>>(variant)) {
+    return std::get<std::vector<Torus>>(variant)[index];
+  } else {
+    return std::get<Torus>(variant);
+  }
+}
 
 int get_active_gpu_count(int num_inputs, int gpu_count);
 
 int get_num_inputs_on_gpu(int total_num_inputs, int gpu_index, int gpu_count);
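To make the new helper concrete, here is a minimal self-contained sketch (illustrative only, not part of the commit; it copies the definitions from the hunk above) exercising both alternatives of LweArrayVariant:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>

template <typename Torus>
using LweArrayVariant = std::variant<std::vector<Torus *>, Torus *>;

template <typename Torus>
inline Torus
get_variant_element(const std::variant<std::vector<Torus>, Torus> &variant,
                    size_t index) {
  if (std::holds_alternative<std::vector<Torus>>(variant))
    return std::get<std::vector<Torus>>(variant)[index];
  return std::get<Torus>(variant);
}

int main() {
  uint64_t a = 1, b = 2;
  // Single pointer: the whole batch lives on one GPU; the index is ignored.
  LweArrayVariant<uint64_t> single = &a;
  // Vector of pointers: the batch is scattered across GPUs; index selects one.
  LweArrayVariant<uint64_t> scattered = std::vector<uint64_t *>{&a, &b};
  assert(get_variant_element(single, 1) == &a);
  assert(get_variant_element(scattered, 1) == &b);
  return 0;
}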
@@ -6,7 +6,6 @@
 #include "integer/radix_ciphertext.h"
 #include "keyswitch/keyswitch.h"
 #include "pbs/programmable_bootstrap.cuh"
-#include "pbs/programmable_bootstrap_128.cuh"
 #include "utils/helper_multi_gpu.cuh"
 #include <cmath>
 #include <functional>
@@ -876,11 +875,11 @@ template <typename InputTorus> struct int_noise_squashing_lut {
         get_num_inputs_on_gpu(num_radix_blocks, i, active_gpu_count));
     int8_t *gpu_pbs_buffer;
     uint64_t size = 0;
-    execute_scratch_pbs_128(streams[i], gpu_indexes[i], &gpu_pbs_buffer,
-                            params.small_lwe_dimension, params.glwe_dimension,
-                            params.polynomial_size, params.pbs_level,
-                            num_radix_blocks_on_gpu, allocate_gpu_memory,
-                            params.noise_reduction_type, size);
+    execute_scratch_pbs<__uint128_t>(
+        streams[i], gpu_indexes[i], &gpu_pbs_buffer, params.glwe_dimension,
+        params.small_lwe_dimension, params.polynomial_size, params.pbs_level,
+        params.grouping_factor, num_radix_blocks_on_gpu, params.pbs_type,
+        allocate_gpu_memory, params.noise_reduction_type, size);
     cuda_synchronize_stream(streams[i], gpu_indexes[i]);
     if (i == 0) {
       size_tracker += size;
@@ -157,12 +157,12 @@ void execute_keyswitch_async(cudaStream_t const *streams,
   for (uint i = 0; i < gpu_count; i++) {
     int num_samples_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
 
-    Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
+    Torus *current_lwe_array_out = get_variant_element(lwe_array_out, i);
     Torus *current_lwe_output_indexes =
-        GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-    Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
+        get_variant_element(lwe_output_indexes, i);
+    Torus *current_lwe_array_in = get_variant_element(lwe_array_in, i);
     Torus *current_lwe_input_indexes =
-        GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+        get_variant_element(lwe_input_indexes, i);
 
     // Compute Keyswitch
     host_keyswitch_lwe_ciphertext_vector<Torus>(
@@ -202,9 +202,9 @@ __host__ void host_packing_keyswitch_lwe_list_to_glwe(
 
   auto stride_KSK_buffer = glwe_accumulator_size * level_count;
 
-  // Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
-  // 32, 64, and 128-bit Torus elements We want to keep this as a sanity check
   uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
+  // Shared memory requirement is 4096, 8192, and 16384 bytes respectively for
+  // 32, 64, and 128-bit Torus elements
+  // Sanity check: the shared memory size is a constant defined by the algorithm
   GPU_ASSERT(shared_mem_size <= 1024 * sizeof(Torus),
              "GEMM kernel error: shared memory required might be too large");
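For reference, the GPU_ASSERT bound above corresponds to a tile of 1024 Torus elements; a compile-time sketch (an assumption inferred from the assert, using the GCC/Clang __uint128_t extension) reproduces the three sizes quoted in the comment:

#include <cstdint>

// 1024 elements per tile times the element width gives the quoted sizes.
static_assert(1024 * sizeof(uint32_t) == 4096, "32-bit Torus tile");
static_assert(1024 * sizeof(uint64_t) == 8192, "64-bit Torus tile");
static_assert(1024 * sizeof(__uint128_t) == 16384, "128-bit Torus tile");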
@@ -344,7 +344,7 @@ host_integer_decompress(cudaStream_t const *streams,
   auto active_gpu_count =
       get_active_gpu_count(num_blocks_to_decompress, gpu_count);
   if (active_gpu_count == 1) {
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, (Torus *)d_lwe_array_out->ptr,
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         extracted_lwe, lut->lwe_indexes_in, d_bsks, nullptr, lut->buffer,
@@ -374,7 +374,7 @@ host_integer_decompress(cudaStream_t const *streams,
         compression_params.small_lwe_dimension + 1);
 
     /// Apply PBS
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_array_in_vec, lwe_trivial_indexes_vec, d_bsks, nullptr,
@@ -558,7 +558,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)lwe_array_out->ptr,
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -586,7 +586,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -665,7 +665,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)lwe_array_out->ptr,
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -693,7 +693,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -787,7 +787,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)(lwe_array_out->ptr),
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
@@ -811,7 +811,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -1486,7 +1486,7 @@ void host_full_propagate_inplace(
         streams[0], gpu_indexes[0], mem_ptr->tmp_small_lwe_vector, 1, 2,
         mem_ptr->tmp_small_lwe_vector, 0, 1);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
         mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
         mem_ptr->lut->lut_indexes_vec,
@@ -2344,11 +2344,17 @@ __host__ void integer_radix_apply_noise_squashing_kb(
 
     /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
     /// dimension to a big LWE dimension
-    execute_pbs_128_async<__uint128_t>(
-        streams, gpu_indexes, 1, (__uint128_t *)lwe_array_out->ptr,
-        lut->lut_vec, lwe_after_ks_vec[0], bsks, ms_noise_reduction_key,
-        lut->pbs_buffer, small_lwe_dimension, glwe_dimension, polynomial_size,
-        pbs_base_log, pbs_level, lwe_array_out->num_radix_blocks);
+    ///
+    /// int_noise_squashing_lut doesn't support a different output or lut
+    /// indexing than the trivial
+    execute_pbs_async<uint64_t, __uint128_t>(
+        streams, gpu_indexes, 1, (__uint128_t *)lwe_array_out->ptr,
+        lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec,
+        lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks,
+        ms_noise_reduction_key, lut->pbs_buffer, glwe_dimension,
+        small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
+        grouping_factor, lwe_array_out->num_radix_blocks, params.pbs_type, 0,
+        0);
   } else {
     /// Make sure all data that should be on GPU 0 is indeed there
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
@@ -2367,11 +2373,15 @@ __host__ void integer_radix_apply_noise_squashing_kb(
         ksks, lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log,
         ks_level, lwe_array_out->num_radix_blocks);
 
-    execute_pbs_128_async<__uint128_t>(
-        streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec, lut->lut_vec,
-        lwe_after_ks_vec, bsks, ms_noise_reduction_key, lut->pbs_buffer,
-        small_lwe_dimension, glwe_dimension, polynomial_size, pbs_base_log,
-        pbs_level, lwe_array_out->num_radix_blocks);
+    /// int_noise_squashing_lut doesn't support a different output or lut
+    /// indexing than the trivial
+    execute_pbs_async<uint64_t, __uint128_t>(
+        streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
+        lwe_trivial_indexes_vec, lut->lut_vec, lwe_trivial_indexes_vec,
+        lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
+        lut->pbs_buffer, glwe_dimension, small_lwe_dimension, polynomial_size,
+        pbs_base_log, pbs_level, grouping_factor,
+        lwe_array_out->num_radix_blocks, params.pbs_type, 0, 0);
 
     /// Copy data back to GPU 0 and release vecs
     /// In apply noise squashing we always use trivial indexes
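Here "trivial" indexing means the identity mapping: block j reads input j and writes output j. A conceptual host-side sketch of such an index buffer (illustrative only; the real buffers live on the GPU):

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Identity permutation: entry j equals j, so no reordering is applied.
std::vector<uint64_t> make_trivial_indexes(size_t num_blocks) {
  std::vector<uint64_t> indexes(num_blocks);
  std::iota(indexes.begin(), indexes.end(), 0); // 0, 1, ..., num_blocks - 1
  return indexes;
}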
@@ -404,7 +404,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
         mem_ptr->params.ks_base_log, mem_ptr->params.ks_level,
         total_messages);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)current_blocks->ptr,
         d_pbs_indexes_out, luts_message_carry->lut_vec,
         luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
@@ -479,7 +479,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
         big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log,
         mem_ptr->params.ks_level, num_radix_blocks);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, 1, (Torus *)current_blocks->ptr,
         d_pbs_indexes_out, luts_message_carry->lut_vec,
         luts_message_carry->lut_indexes_vec, (Torus *)small_lwe_vector->ptr,
@@ -34,7 +34,7 @@ void host_integer_grouped_oprf(
   auto lut = mem_ptr->luts;
 
   if (active_gpu_count == 1) {
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, (uint32_t)1, (Torus *)(radix_lwe_out->ptr),
         lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec,
         const_cast<Torus *>(seeded_lwe_input), lut->lwe_indexes_in, bsks,
@@ -60,7 +60,7 @@ void host_integer_grouped_oprf(
         active_gpu_count, num_blocks_to_process,
         mem_ptr->params.small_lwe_dimension + 1);
 
-    execute_pbs_async<Torus>(
+    execute_pbs_async<Torus, Torus>(
         streams, gpu_indexes, active_gpu_count, lwe_after_pbs_vec,
         lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
         lwe_array_in_vec, lwe_trivial_indexes_vec, bsks, ms_noise_reduction_key,
@@ -7,6 +7,7 @@
 #include "device.h"
 #include "fft/bnsmfft.cuh"
 #include "helper_multi_gpu.h"
+#include "pbs/pbs_128_utilities.h"
 #include "pbs/programmable_bootstrap_multibit.h"
 #include "polynomial/polynomial_math.cuh"
 
@@ -202,15 +203,15 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params(
   // the buffer in registers to avoid synchronizations and shared memory usage
 }
 
-template <typename Torus>
+template <typename InputTorus, typename OutputTorus>
 void execute_pbs_async(
     cudaStream_t const *streams, uint32_t const *gpu_indexes,
-    uint32_t gpu_count, const LweArrayVariant<Torus> &lwe_array_out,
-    const LweArrayVariant<Torus> &lwe_output_indexes,
-    const std::vector<Torus *> lut_vec,
-    const std::vector<Torus *> lut_indexes_vec,
-    const LweArrayVariant<Torus> &lwe_array_in,
-    const LweArrayVariant<Torus> &lwe_input_indexes,
+    uint32_t gpu_count, const LweArrayVariant<OutputTorus> &lwe_array_out,
+    const LweArrayVariant<InputTorus> &lwe_output_indexes,
+    const std::vector<OutputTorus *> lut_vec,
+    const std::vector<InputTorus *> lut_indexes_vec,
+    const LweArrayVariant<InputTorus> &lwe_array_in,
+    const LweArrayVariant<InputTorus> &lwe_input_indexes,
     void *const *bootstrapping_keys,
     CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
     std::vector<int8_t *> pbs_buffer, uint32_t glwe_dimension,
@@ -219,8 +220,7 @@ void execute_pbs_async(
     uint32_t input_lwe_ciphertext_count, PBS_TYPE pbs_type,
     uint32_t num_many_lut, uint32_t lut_stride) {
 
-  switch (sizeof(Torus)) {
-  case sizeof(uint32_t):
+  if constexpr (std::is_same_v<OutputTorus, uint32_t>) {
     // 32 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -238,12 +238,12 @@ void execute_pbs_async(
       // Use the macro to get the correct elements for the current iteration
       // Handles the case when the input/output are scattered through
       // different gpus and when it is not
-      Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-      Torus *current_lwe_output_indexes =
-          GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-      Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
-      Torus *current_lwe_input_indexes =
-          GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+      auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+      auto current_lwe_output_indexes =
+          get_variant_element(lwe_output_indexes, i);
+      auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+      auto current_lwe_input_indexes =
+          get_variant_element(lwe_input_indexes, i);
 
       cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
           streams[i], gpu_indexes[i], current_lwe_array_out,
@@ -257,8 +257,7 @@ void execute_pbs_async(
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  case sizeof(uint64_t):
+  } else if constexpr (std::is_same_v<OutputTorus, uint64_t>) {
     // 64 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -271,12 +270,12 @@ void execute_pbs_async(
       // Use the macro to get the correct elements for the current iteration
       // Handles the case when the input/output are scattered through
       // different gpus and when it is not
-      Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-      Torus *current_lwe_output_indexes =
-          GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-      Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
-      Torus *current_lwe_input_indexes =
-          GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+      auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+      auto current_lwe_output_indexes =
+          get_variant_element(lwe_output_indexes, i);
+      auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+      auto current_lwe_input_indexes =
+          get_variant_element(lwe_input_indexes, i);
 
       int gpu_offset =
           get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
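The slicing above is driven by get_num_inputs_on_gpu and get_gpu_offset. A sketch of the assumed semantics (an even split with the remainder spread over the first GPUs; these reimplementations are illustrative, not the backend's code):

// Assumed semantics, reimplemented for illustration only.
int sketch_num_inputs_on_gpu(int total_num_inputs, int gpu_index,
                             int gpu_count) {
  int base = total_num_inputs / gpu_count;
  return base + (gpu_index < total_num_inputs % gpu_count ? 1 : 0);
}

// Offset of the first input handled by gpu_index: the sum of the loads of all
// preceding GPUs. lut_indexes_vec[i] + gpu_offset then points at the LUT
// indexes of this GPU's slice.
int sketch_gpu_offset(int total_num_inputs, int gpu_index, int gpu_count) {
  int offset = 0;
  for (int g = 0; g < gpu_index; g++)
    offset += sketch_num_inputs_on_gpu(total_num_inputs, g, gpu_count);
  return offset;
}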
@@ -300,12 +299,12 @@ void execute_pbs_async(
       // Use the macro to get the correct elements for the current iteration
       // Handles the case when the input/output are scattered through
       // different gpus and when it is not
-      Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-      Torus *current_lwe_output_indexes =
-          GET_VARIANT_ELEMENT(lwe_output_indexes, i);
-      Torus *current_lwe_array_in = GET_VARIANT_ELEMENT(lwe_array_in, i);
-      Torus *current_lwe_input_indexes =
-          GET_VARIANT_ELEMENT(lwe_input_indexes, i);
+      auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+      auto current_lwe_output_indexes =
+          get_variant_element(lwe_output_indexes, i);
+      auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+      auto current_lwe_input_indexes =
+          get_variant_element(lwe_input_indexes, i);
 
       int gpu_offset =
           get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
@@ -328,10 +327,81 @@ void execute_pbs_async(
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  default:
-    PANIC("Cuda error: unsupported modulus size: only 32 and 64 bit integer "
-          "moduli are supported.")
+  } else if constexpr (std::is_same_v<OutputTorus, __uint128_t>) {
+    // 128 bits
+    switch (pbs_type) {
+    case MULTI_BIT:
+      if (grouping_factor == 0)
+        PANIC("Multi-bit PBS error: grouping factor should be > 0.")
+      for (uint i = 0; i < gpu_count; i++) {
+        int num_inputs_on_gpu =
+            get_num_inputs_on_gpu(input_lwe_ciphertext_count, i, gpu_count);
+
+        // Use the macro to get the correct elements for the current iteration
+        // Handles the case when the input/output are scattered through
+        // different gpus and when it is not
+        auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+        auto current_lwe_output_indexes =
+            get_variant_element(lwe_output_indexes, i);
+        auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+        auto current_lwe_input_indexes =
+            get_variant_element(lwe_input_indexes, i);
+
+        int gpu_offset =
+            get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
+        auto d_lut_vector_indexes =
+            lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
+
+        cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
+            streams[i], gpu_indexes[i], current_lwe_array_out,
+            current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
+            current_lwe_array_in, current_lwe_input_indexes,
+            bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
+            polynomial_size, grouping_factor, base_log, level_count,
+            num_inputs_on_gpu, num_many_lut, lut_stride);
+      }
+      break;
+    case CLASSICAL:
+      for (uint i = 0; i < gpu_count; i++) {
+        int num_inputs_on_gpu =
+            get_num_inputs_on_gpu(input_lwe_ciphertext_count, i, gpu_count);
+
+        // Use the macro to get the correct elements for the current iteration
+        // Handles the case when the input/output are scattered through
+        // different gpus and when it is not
+        auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
+        auto current_lwe_output_indexes =
+            get_variant_element(lwe_output_indexes, i);
+        auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
+        auto current_lwe_input_indexes =
+            get_variant_element(lwe_input_indexes, i);
+
+        int gpu_offset =
+            get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
+        auto d_lut_vector_indexes =
+            lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
+
+        void *zeros = nullptr;
+        if (ms_noise_reduction_key != nullptr &&
+            ms_noise_reduction_key->ptr != nullptr)
+          zeros = ms_noise_reduction_key->ptr[i];
+        cuda_programmable_bootstrap_lwe_ciphertext_vector_128(
+            streams[i], gpu_indexes[i], current_lwe_array_out, lut_vec[i],
+            current_lwe_array_in, bootstrapping_keys[i], ms_noise_reduction_key,
+            zeros, pbs_buffer[i], lwe_dimension, glwe_dimension,
+            polynomial_size, base_log, level_count, num_inputs_on_gpu);
+      }
+      break;
+    default:
+      PANIC("Error: unsupported cuda PBS type.")
+    }
+  } else {
+    static_assert(
+        std::is_same_v<OutputTorus, uint32_t> ||
+            std::is_same_v<OutputTorus, uint64_t> ||
+            std::is_same_v<OutputTorus, __uint128_t>,
+        "Cuda error: unsupported modulus size: only 32, 64, or 128-bit integer "
+        "moduli are supported.");
   }
 }
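The overall shape of the new dispatch, reduced to its essentials (an illustrative sketch, not the backend's code): unlike the old switch (sizeof(Torus)), only the branch whose type matches is instantiated, so each branch can call entry points that exist solely for that width, and an unsupported type fails at compile time rather than at run time:

#include <cstdint>
#include <type_traits>

template <typename OutputTorus> void dispatch_on_width() {
  if constexpr (std::is_same_v<OutputTorus, uint32_t>) {
    // Only instantiated for 32-bit: may call *_32 entry points.
  } else if constexpr (std::is_same_v<OutputTorus, uint64_t>) {
    // Only instantiated for 64-bit: may call *_64 entry points.
  } else if constexpr (std::is_same_v<OutputTorus, __uint128_t>) {
    // Only instantiated for 128-bit: may call *_128 entry points.
  } else {
    // Dependent-false assert: fires only if this branch is instantiated.
    static_assert(!std::is_same_v<OutputTorus, OutputTorus>,
                  "unsupported modulus width");
  }
}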
@@ -344,8 +414,7 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
                          bool allocate_gpu_memory,
                          PBS_MS_REDUCTION_T noise_reduction_type,
                          uint64_t &size_tracker) {
-  switch (sizeof(Torus)) {
-  case sizeof(uint32_t):
+  if constexpr (std::is_same_v<Torus, uint32_t>) {
    // 32 bits
    switch (pbs_type) {
    case MULTI_BIT:
@@ -359,8 +428,7 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  case sizeof(uint64_t):
+  } else if constexpr (std::is_same_v<Torus, uint64_t>) {
     // 64 bits
     switch (pbs_type) {
     case MULTI_BIT:
@@ -379,10 +447,32 @@ void execute_scratch_pbs(cudaStream_t stream, uint32_t gpu_index,
     default:
       PANIC("Error: unsupported cuda PBS type.")
     }
-    break;
-  default:
-    PANIC("Cuda error: unsupported modulus size: only 32 and 64 bit integer "
-          "moduli are supported.")
+  } else if constexpr (std::is_same_v<Torus, __uint128_t>) {
+    // 128 bits
+    switch (pbs_type) {
+    case MULTI_BIT:
+      if (grouping_factor == 0)
+        PANIC("Multi-bit PBS error: grouping factor should be > 0.")
+      size_tracker =
+          scratch_cuda_multi_bit_programmable_bootstrap_128_vector_64(
+              stream, gpu_index, pbs_buffer, glwe_dimension, polynomial_size,
+              level_count, input_lwe_ciphertext_count, allocate_gpu_memory);
+      break;
+    case CLASSICAL:
+      size_tracker = scratch_cuda_programmable_bootstrap_128(
+          stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
+          polynomial_size, level_count, input_lwe_ciphertext_count,
+          allocate_gpu_memory, noise_reduction_type);
+      break;
+    default:
+      PANIC("Error: unsupported cuda PBS type.")
+    }
+  } else {
+    static_assert(
+        std::is_same_v<Torus, uint32_t> || std::is_same_v<Torus, uint64_t> ||
+            std::is_same_v<Torus, __uint128_t>,
+        "Cuda error: unsupported modulus size: only 32, 64, or 128-bit integer "
+        "moduli are supported.");
   }
 }
@@ -1,45 +0,0 @@
-#ifndef CUDA_PROGRAMMABLE_BOOTSTRAP_128_CUH
-#define CUDA_PROGRAMMABLE_BOOTSTRAP_128_CUH
-#include "pbs/pbs_128_utilities.h"
-
-static void execute_scratch_pbs_128(
-    void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
-    uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
-    uint32_t level_count, uint32_t input_lwe_ciphertext_count,
-    bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type,
-    uint64_t &size_tracker_on_gpu) {
-  // The squash noise function receives as input 64-bit integers
-  size_tracker_on_gpu = scratch_cuda_programmable_bootstrap_128_vector_64(
-      stream, gpu_index, pbs_buffer, lwe_dimension, glwe_dimension,
-      polynomial_size, level_count, input_lwe_ciphertext_count,
-      allocate_gpu_memory, noise_reduction_type);
-}
-
-template <typename Torus>
-static void execute_pbs_128_async(
-    cudaStream_t const *streams, uint32_t const *gpu_indexes,
-    uint32_t gpu_count, const LweArrayVariant<__uint128_t> &lwe_array_out,
-    const std::vector<Torus *> lut_vector,
-    const LweArrayVariant<uint64_t> &lwe_array_in,
-    void *const *bootstrapping_keys,
-    CudaModulusSwitchNoiseReductionKeyFFI const *ms_noise_reduction_key,
-    std::vector<int8_t *> pbs_buffer, uint32_t lwe_dimension,
-    uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
-    uint32_t level_count, uint32_t num_samples) {
-
-  for (uint32_t i = 0; i < gpu_count; i++) {
-    int num_inputs_on_gpu = get_num_inputs_on_gpu(num_samples, i, gpu_count);
-
-    Torus *current_lwe_array_out = GET_VARIANT_ELEMENT(lwe_array_out, i);
-    uint64_t *current_lwe_array_in = GET_VARIANT_ELEMENT_64BIT(lwe_array_in, i);
-    void *zeros = nullptr;
-    if (ms_noise_reduction_key != nullptr)
-      zeros = ms_noise_reduction_key->ptr[i];
-
-    cuda_programmable_bootstrap_lwe_ciphertext_vector_128(
-        streams[i], gpu_indexes[i], current_lwe_array_out, lut_vector[i],
-        current_lwe_array_in, bootstrapping_keys[i], ms_noise_reduction_key,
-        zeros, pbs_buffer[i], lwe_dimension, glwe_dimension, polynomial_size,
-        base_log, level_count, num_inputs_on_gpu);
-  }
-}
-#endif