fix(gpu): refactor the (128-bit and regular) classical PBS entry point to remove the num_samples parameter

- fixes the throughput for those PBSs
- also fixes the throughput benchmark for regular PBSs
This commit is contained in:
Pedro Alves
2025-07-02 09:34:01 -03:00
committed by Pedro Alves
parent d955696fe0
commit 22ddba7145
9 changed files with 14 additions and 25 deletions

View File

@@ -640,7 +640,6 @@ mod cuda {
&cuda_indexes.d_lut,
&cuda_indexes.d_output,
&cuda_indexes.d_input,
LweCiphertextCount(1),
gpu_keys.bsk.as_ref().unwrap(),
&streams,
);
@@ -793,7 +792,6 @@ mod cuda {
&cuda_indexes_vec[i].d_lut,
&cuda_indexes_vec[i].d_output,
&cuda_indexes_vec[i].d_input,
LweCiphertextCount(1),
gpu_keys_vec[i].bsk.as_ref().unwrap(),
local_stream,
);

View File

@@ -302,7 +302,6 @@ mod cuda {
&lwe_ciphertext_in_gpu,
&mut out_pbs_ct_gpu,
&accumulator_gpu,
LweCiphertextCount(1),
gpu_keys.bsk.as_ref().unwrap(),
&streams,
);
@@ -398,12 +397,14 @@ mod cuda {
.zip(accumulators.par_iter())
.zip(local_streams.par_iter())
.for_each(
|((((i, input_ct), output_ct), accumulator), local_stream)| {
|(
(((i, input_batch), output_batch), accumulator),
local_stream,
)| {
cuda_programmable_bootstrap_128_lwe_ciphertext(
input_ct,
output_ct,
input_batch,
output_batch,
accumulator,
LweCiphertextCount(1),
gpu_keys_vec[i].bsk.as_ref().unwrap(),
local_stream,
);

View File

@@ -1031,7 +1031,6 @@ mod cuda {
&cuda_indexes.d_lut,
&cuda_indexes.d_output,
&cuda_indexes.d_input,
LweCiphertextCount(1),
gpu_keys.bsk.as_ref().unwrap(),
&streams,
);
@@ -1113,7 +1112,7 @@ mod cuda {
})
.collect::<Vec<_>>();
let h_indexes = (0..(elements / gpu_count as u64))
let h_indexes = (0..elements_per_stream as u64)
.map(CastFrom::cast_from)
.collect::<Vec<_>>();
let cuda_indexes_vec = (0..gpu_count)
@@ -1157,7 +1156,6 @@ mod cuda {
&cuda_indexes_vec[i].d_lut,
&cuda_indexes_vec[i].d_output,
&cuda_indexes_vec[i].d_input,
LweCiphertextCount(1),
gpu_keys_vec[i].bsk.as_ref().unwrap(),
local_stream,
);

View File

@@ -579,10 +579,12 @@ mod cuda_utils {
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let zeros = vec![T::ZERO; length];
unsafe {
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_lut.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
}
stream.synchronize();

View File

@@ -5,7 +5,7 @@ use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{
programmable_bootstrap_128_async, programmable_bootstrap_async, CudaStreams,
};
use crate::core_crypto::prelude::{CastInto, LweCiphertextCount, UnsignedTorus};
use crate::core_crypto::prelude::{CastInto, UnsignedTorus};
/// # Safety
///
@@ -19,7 +19,6 @@ pub unsafe fn cuda_programmable_bootstrap_lwe_ciphertext_async<Scalar>(
lut_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
input_indexes: &CudaVec<Scalar>,
num_samples: LweCiphertextCount,
bsk: &CudaLweBootstrapKey,
streams: &CudaStreams,
) where
@@ -129,7 +128,7 @@ pub unsafe fn cuda_programmable_bootstrap_lwe_ciphertext_async<Scalar>(
let lwe_dimension = input.lwe_dimension();
let ct_modulus = input.ciphertext_modulus().raw_modulus_float();
let num_samples = input.lwe_ciphertext_count();
programmable_bootstrap_async(
streams,
&mut output.0.d_vec,
@@ -159,7 +158,6 @@ pub unsafe fn cuda_programmable_bootstrap_128_lwe_ciphertext_async<Scalar>(
input: &CudaLweCiphertextList<u64>,
output: &mut CudaLweCiphertextList<Scalar>,
accumulator: &CudaGlweCiphertextList<Scalar>,
num_samples: LweCiphertextCount,
bsk: &CudaLweBootstrapKey,
streams: &CudaStreams,
) where
@@ -255,6 +253,7 @@ pub unsafe fn cuda_programmable_bootstrap_128_lwe_ciphertext_async<Scalar>(
);
let lwe_dimension = input.lwe_dimension();
let ct_modulus = input.ciphertext_modulus().raw_modulus_float();
let num_samples = input.lwe_ciphertext_count();
programmable_bootstrap_128_async(
streams,
&mut output.0.d_vec,
@@ -280,7 +279,6 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext<Scalar>(
lut_indexes: &CudaVec<Scalar>,
output_indexes: &CudaVec<Scalar>,
input_indexes: &CudaVec<Scalar>,
num_samples: LweCiphertextCount,
bsk: &CudaLweBootstrapKey,
streams: &CudaStreams,
) where
@@ -294,7 +292,6 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext<Scalar>(
lut_indexes,
output_indexes,
input_indexes,
num_samples,
bsk,
streams,
);
@@ -318,7 +315,6 @@ pub fn cuda_programmable_bootstrap_128_lwe_ciphertext<Scalar>(
input: &CudaLweCiphertextList<u64>,
output: &mut CudaLweCiphertextList<Scalar>,
accumulator: &CudaGlweCiphertextList<Scalar>,
num_samples: LweCiphertextCount,
bsk: &CudaLweBootstrapKey,
streams: &CudaStreams,
) where
@@ -329,7 +325,6 @@ pub fn cuda_programmable_bootstrap_128_lwe_ciphertext<Scalar>(
input,
output,
accumulator,
num_samples,
bsk,
streams,
);

View File

@@ -153,7 +153,6 @@ fn lwe_encrypt_pbs_decrypt<
&d_test_vector_indexes,
&d_output_indexes,
&d_input_indexes,
LweCiphertextCount(num_blocks),
&d_bsk,
&stream,
);

View File

@@ -180,12 +180,10 @@ pub fn execute_bootstrap_u128(
let d_accumulator = CudaGlweCiphertextList::from_glwe_ciphertext(&accumulator, &stream);
let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0;
cuda_programmable_bootstrap_128_lwe_ciphertext(
&d_lwe_ciphertext_in,
&mut d_out_pbs_ct,
&d_accumulator,
LweCiphertextCount(num_blocks),
&d_bsk,
&stream,
);

View File

@@ -181,7 +181,6 @@ where
&d_test_vector_indexes,
&d_output_indexes,
&d_input_indexes,
LweCiphertextCount(num_blocks),
&d_bsk,
&stream,
);

View File

@@ -25,7 +25,7 @@ use crate::core_crypto::commons::numeric::Numeric;
use crate::core_crypto::gpu::add_lwe_ciphertext_vector_plaintext_scalar_async;
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::prelude::CastInto;
use crate::integer::gpu::server_key::radix::{CudaLweCiphertextList, LweCiphertextCount};
use crate::integer::gpu::server_key::radix::CudaLweCiphertextList;
use crate::integer::gpu::CudaVec;
use itertools::Itertools;
@@ -574,7 +574,6 @@ impl CudaServerKey {
&d_lut_vector_indexes,
&d_output_indexes,
&d_input_indexes,
LweCiphertextCount(num_ct_blocks),
d_bsk,
streams,
);