mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
fix(gpu): refactor the (128-bit and regular) classical PBS entry point to remove the num_samples parameter
- fixes the throughput for those PBSs - also fixes the throughput benchmark for regular PBSs
This commit is contained in:
@@ -640,7 +640,6 @@ mod cuda {
|
||||
&cuda_indexes.d_lut,
|
||||
&cuda_indexes.d_output,
|
||||
&cuda_indexes.d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys.bsk.as_ref().unwrap(),
|
||||
&streams,
|
||||
);
|
||||
@@ -793,7 +792,6 @@ mod cuda {
|
||||
&cuda_indexes_vec[i].d_lut,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys_vec[i].bsk.as_ref().unwrap(),
|
||||
local_stream,
|
||||
);
|
||||
|
||||
@@ -302,7 +302,6 @@ mod cuda {
|
||||
&lwe_ciphertext_in_gpu,
|
||||
&mut out_pbs_ct_gpu,
|
||||
&accumulator_gpu,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys.bsk.as_ref().unwrap(),
|
||||
&streams,
|
||||
);
|
||||
@@ -398,12 +397,14 @@ mod cuda {
|
||||
.zip(accumulators.par_iter())
|
||||
.zip(local_streams.par_iter())
|
||||
.for_each(
|
||||
|((((i, input_ct), output_ct), accumulator), local_stream)| {
|
||||
|(
|
||||
(((i, input_batch), output_batch), accumulator),
|
||||
local_stream,
|
||||
)| {
|
||||
cuda_programmable_bootstrap_128_lwe_ciphertext(
|
||||
input_ct,
|
||||
output_ct,
|
||||
input_batch,
|
||||
output_batch,
|
||||
accumulator,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys_vec[i].bsk.as_ref().unwrap(),
|
||||
local_stream,
|
||||
);
|
||||
|
||||
@@ -1031,7 +1031,6 @@ mod cuda {
|
||||
&cuda_indexes.d_lut,
|
||||
&cuda_indexes.d_output,
|
||||
&cuda_indexes.d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys.bsk.as_ref().unwrap(),
|
||||
&streams,
|
||||
);
|
||||
@@ -1113,7 +1112,7 @@ mod cuda {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let h_indexes = (0..(elements / gpu_count as u64))
|
||||
let h_indexes = (0..elements_per_stream as u64)
|
||||
.map(CastFrom::cast_from)
|
||||
.collect::<Vec<_>>();
|
||||
let cuda_indexes_vec = (0..gpu_count)
|
||||
@@ -1157,7 +1156,6 @@ mod cuda {
|
||||
&cuda_indexes_vec[i].d_lut,
|
||||
&cuda_indexes_vec[i].d_output,
|
||||
&cuda_indexes_vec[i].d_input,
|
||||
LweCiphertextCount(1),
|
||||
gpu_keys_vec[i].bsk.as_ref().unwrap(),
|
||||
local_stream,
|
||||
);
|
||||
|
||||
@@ -579,10 +579,12 @@ mod cuda_utils {
|
||||
let mut d_input = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_output = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let mut d_lut = unsafe { CudaVec::<T>::new_async(length, stream, stream_index) };
|
||||
let zeros = vec![T::ZERO; length];
|
||||
|
||||
unsafe {
|
||||
d_input.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_output.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_lut.copy_from_cpu_async(indexes.as_ref(), stream, stream_index);
|
||||
d_lut.copy_from_cpu_async(zeros.as_ref(), stream, stream_index);
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::core_crypto::gpu::vec::CudaVec;
|
||||
use crate::core_crypto::gpu::{
|
||||
programmable_bootstrap_128_async, programmable_bootstrap_async, CudaStreams,
|
||||
};
|
||||
use crate::core_crypto::prelude::{CastInto, LweCiphertextCount, UnsignedTorus};
|
||||
use crate::core_crypto::prelude::{CastInto, UnsignedTorus};
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
@@ -19,7 +19,6 @@ pub unsafe fn cuda_programmable_bootstrap_lwe_ciphertext_async<Scalar>(
|
||||
lut_indexes: &CudaVec<Scalar>,
|
||||
output_indexes: &CudaVec<Scalar>,
|
||||
input_indexes: &CudaVec<Scalar>,
|
||||
num_samples: LweCiphertextCount,
|
||||
bsk: &CudaLweBootstrapKey,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
@@ -129,7 +128,7 @@ pub unsafe fn cuda_programmable_bootstrap_lwe_ciphertext_async<Scalar>(
|
||||
|
||||
let lwe_dimension = input.lwe_dimension();
|
||||
let ct_modulus = input.ciphertext_modulus().raw_modulus_float();
|
||||
|
||||
let num_samples = input.lwe_ciphertext_count();
|
||||
programmable_bootstrap_async(
|
||||
streams,
|
||||
&mut output.0.d_vec,
|
||||
@@ -159,7 +158,6 @@ pub unsafe fn cuda_programmable_bootstrap_128_lwe_ciphertext_async<Scalar>(
|
||||
input: &CudaLweCiphertextList<u64>,
|
||||
output: &mut CudaLweCiphertextList<Scalar>,
|
||||
accumulator: &CudaGlweCiphertextList<Scalar>,
|
||||
num_samples: LweCiphertextCount,
|
||||
bsk: &CudaLweBootstrapKey,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
@@ -255,6 +253,7 @@ pub unsafe fn cuda_programmable_bootstrap_128_lwe_ciphertext_async<Scalar>(
|
||||
);
|
||||
let lwe_dimension = input.lwe_dimension();
|
||||
let ct_modulus = input.ciphertext_modulus().raw_modulus_float();
|
||||
let num_samples = input.lwe_ciphertext_count();
|
||||
programmable_bootstrap_128_async(
|
||||
streams,
|
||||
&mut output.0.d_vec,
|
||||
@@ -280,7 +279,6 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext<Scalar>(
|
||||
lut_indexes: &CudaVec<Scalar>,
|
||||
output_indexes: &CudaVec<Scalar>,
|
||||
input_indexes: &CudaVec<Scalar>,
|
||||
num_samples: LweCiphertextCount,
|
||||
bsk: &CudaLweBootstrapKey,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
@@ -294,7 +292,6 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext<Scalar>(
|
||||
lut_indexes,
|
||||
output_indexes,
|
||||
input_indexes,
|
||||
num_samples,
|
||||
bsk,
|
||||
streams,
|
||||
);
|
||||
@@ -318,7 +315,6 @@ pub fn cuda_programmable_bootstrap_128_lwe_ciphertext<Scalar>(
|
||||
input: &CudaLweCiphertextList<u64>,
|
||||
output: &mut CudaLweCiphertextList<Scalar>,
|
||||
accumulator: &CudaGlweCiphertextList<Scalar>,
|
||||
num_samples: LweCiphertextCount,
|
||||
bsk: &CudaLweBootstrapKey,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
@@ -329,7 +325,6 @@ pub fn cuda_programmable_bootstrap_128_lwe_ciphertext<Scalar>(
|
||||
input,
|
||||
output,
|
||||
accumulator,
|
||||
num_samples,
|
||||
bsk,
|
||||
streams,
|
||||
);
|
||||
|
||||
@@ -153,7 +153,6 @@ fn lwe_encrypt_pbs_decrypt<
|
||||
&d_test_vector_indexes,
|
||||
&d_output_indexes,
|
||||
&d_input_indexes,
|
||||
LweCiphertextCount(num_blocks),
|
||||
&d_bsk,
|
||||
&stream,
|
||||
);
|
||||
|
||||
@@ -180,12 +180,10 @@ pub fn execute_bootstrap_u128(
|
||||
|
||||
let d_accumulator = CudaGlweCiphertextList::from_glwe_ciphertext(&accumulator, &stream);
|
||||
|
||||
let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0;
|
||||
cuda_programmable_bootstrap_128_lwe_ciphertext(
|
||||
&d_lwe_ciphertext_in,
|
||||
&mut d_out_pbs_ct,
|
||||
&d_accumulator,
|
||||
LweCiphertextCount(num_blocks),
|
||||
&d_bsk,
|
||||
&stream,
|
||||
);
|
||||
|
||||
@@ -181,7 +181,6 @@ where
|
||||
&d_test_vector_indexes,
|
||||
&d_output_indexes,
|
||||
&d_input_indexes,
|
||||
LweCiphertextCount(num_blocks),
|
||||
&d_bsk,
|
||||
&stream,
|
||||
);
|
||||
|
||||
@@ -25,7 +25,7 @@ use crate::core_crypto::commons::numeric::Numeric;
|
||||
use crate::core_crypto::gpu::add_lwe_ciphertext_vector_plaintext_scalar_async;
|
||||
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
|
||||
use crate::core_crypto::prelude::CastInto;
|
||||
use crate::integer::gpu::server_key::radix::{CudaLweCiphertextList, LweCiphertextCount};
|
||||
use crate::integer::gpu::server_key::radix::CudaLweCiphertextList;
|
||||
use crate::integer::gpu::CudaVec;
|
||||
use itertools::Itertools;
|
||||
|
||||
@@ -574,7 +574,6 @@ impl CudaServerKey {
|
||||
&d_lut_vector_indexes,
|
||||
&d_output_indexes,
|
||||
&d_input_indexes,
|
||||
LweCiphertextCount(num_ct_blocks),
|
||||
d_bsk,
|
||||
streams,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user