fix(gpu): refactor the (128-bit and regular) classical PBS entry point to remove the num_samples parameter

- fixes the throughput of those PBSs (call sites no longer pass a hard-coded `LweCiphertextCount(1)`)
- also fixes the throughput benchmark for regular PBSs
This commit is contained in:
Pedro Alves
2025-07-02 09:34:01 -03:00
committed by Pedro Alves
parent d955696fe0
commit 22ddba7145
9 changed files with 14 additions and 25 deletions

View File

@@ -640,7 +640,6 @@ mod cuda {
&cuda_indexes.d_lut,
&cuda_indexes.d_output,
&cuda_indexes.d_input,
LweCiphertextCount(1),
gpu_keys.bsk.as_ref().unwrap(),
&streams,
);
@@ -793,7 +792,6 @@ mod cuda {
&cuda_indexes_vec[i].d_lut,
&cuda_indexes_vec[i].d_output,
&cuda_indexes_vec[i].d_input,
LweCiphertextCount(1),
gpu_keys_vec[i].bsk.as_ref().unwrap(),
local_stream,
);

View File

@@ -302,7 +302,6 @@ mod cuda {
&lwe_ciphertext_in_gpu,
&mut out_pbs_ct_gpu,
&accumulator_gpu,
LweCiphertextCount(1),
gpu_keys.bsk.as_ref().unwrap(),
&streams,
);
@@ -398,12 +397,14 @@ mod cuda {
.zip(accumulators.par_iter())
.zip(local_streams.par_iter())
.for_each(
|((((i, input_ct), output_ct), accumulator), local_stream)| {
|(
(((i, input_batch), output_batch), accumulator),
local_stream,
)| {
cuda_programmable_bootstrap_128_lwe_ciphertext(
input_ct,
output_ct,
input_batch,
output_batch,
accumulator,
LweCiphertextCount(1),
gpu_keys_vec[i].bsk.as_ref().unwrap(),
local_stream,
);

View File

@@ -1031,7 +1031,6 @@ mod cuda {
&cuda_indexes.d_lut,
&cuda_indexes.d_output,
&cuda_indexes.d_input,
LweCiphertextCount(1),
gpu_keys.bsk.as_ref().unwrap(),
&streams,
);
@@ -1113,7 +1112,7 @@ mod cuda {
})
.collect::<Vec<_>>();
let h_indexes = (0..(elements / gpu_count as u64))
let h_indexes = (0..elements_per_stream as u64)
.map(CastFrom::cast_from)
.collect::<Vec<_>>();
let cuda_indexes_vec = (0..gpu_count)
@@ -1157,7 +1156,6 @@ mod cuda {
&cuda_indexes_vec[i].d_lut,
&cuda_indexes_vec[i].d_output,
&cuda_indexes_vec[i].d_input,
LweCiphertextCount(1),
gpu_keys_vec[i].bsk.as_ref().unwrap(),
local_stream,
);

View File

@@ -579,10 +579,12 @@ mod cuda_utils {
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
let zeros = vec![T::ZERO; length];
unsafe {
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_lut.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
}
stream.synchronize();