mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
fix(gpu): refactor the (128-bit and regular) classical PBS entry point to remove the num_samples parameter
- fixes the throughput for those PBSs - also fixes the throughput benchmark for regular PBSs
This commit is contained in:
@@ -640,7 +640,6 @@ mod cuda {
|
||||
&cuda_indexes.d_lut,
|
||||
&cuda_indexes.d_output,
|
||||
&cuda_indexes.d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys.bsk.as_ref().unwrap(),
|
||||
&streams,
|
||||
);
|
||||
@@ -793,7 +792,6 @@ mod cuda {
|
||||
&cuda_indexes_vec[i].d_lut,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys_vec[i].bsk.as_ref().unwrap(),
|
||||
local_stream,
|
||||
);
|
||||
|
||||
@@ -302,7 +302,6 @@ mod cuda {
|
||||
&lwe_ciphertext_in_gpu,
|
||||
&mut out_pbs_ct_gpu,
|
||||
&accumulator_gpu,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys.bsk.as_ref().unwrap(),
|
||||
&streams,
|
||||
);
|
||||
@@ -398,12 +397,14 @@ mod cuda {
|
||||
.zip(accumulators.par_iter())
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(
|
||||
|((((i, input_ct), output_ct), accumulator), local_stream)| {
|
||||
|(
|
||||
(((i, input_batch), output_batch), accumulator),
|
||||
local_stream,
|
||||
)| {
|
||||
cuda_programmable_bootstrap_128_lwe_ciphertext(
|
||||
input_ct,
|
||||
output_ct,
|
||||
input_batch,
|
||||
output_batch,
|
||||
accumulator,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys_vec[i].bsk.as_ref().unwrap(),
|
||||
local_stream,
|
||||
);
|
||||
|
||||
@@ -1031,7 +1031,6 @@ mod cuda {
|
||||
&cuda_indexes.d_lut,
|
||||
&cuda_indexes.d_output,
|
||||
&cuda_indexes.d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys.bsk.as_ref().unwrap(),
|
||||
&streams,
|
||||
);
|
||||
@@ -1113,7 +1112,7 @@ mod cuda {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let h_indexes = (0..(elements / gpu_count as u64))
|
||||
let h_indexes = (0..elements_per_stream as u64)
|
||||
.map(CastFrom::cast_from)
|
||||
.collect::<Vec<_>>();
|
||||
let cuda_indexes_vec = (0..gpu_count)
|
||||
@@ -1157,7 +1156,6 @@ mod cuda {
|
||||
&cuda_indexes_vec[i].d_lut,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys_vec[i].bsk.as_ref().unwrap(),
|
||||
local_stream,
|
||||
);
|
||||
|
||||
@@ -579,10 +579,12 @@ mod cuda_utils {
|
||||
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let zeros = vec![T::ZERO; length];
|
||||
|
||||
unsafe {
|
||||
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_lut.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user