mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(gpu): bench KS latency batches
This commit is contained in:
@@ -386,7 +386,7 @@ mod cuda {
|
||||
.keyswitch_key(ksk_big_to_small)
|
||||
.build();
|
||||
|
||||
let bench_id;
|
||||
let mut bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
@@ -429,6 +429,72 @@ mod cuda {
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
for elements_per_stream in (2..=32u64)
|
||||
{
|
||||
let plaintext_list = PlaintextList::new(
|
||||
Scalar::ZERO,
|
||||
PlaintextCount(elements_per_stream as usize),
|
||||
);
|
||||
|
||||
let mut input_ct_list = LweCiphertextList::new(
|
||||
Scalar::ZERO,
|
||||
big_lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream as usize),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
encrypt_lwe_ciphertext_list(
|
||||
&big_lwe_sk,
|
||||
&mut input_ct_list,
|
||||
&plaintext_list,
|
||||
params.lwe_noise_distribution.unwrap(),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
let input_ks_list = LweCiphertextList::from_container(
|
||||
input_ct_list.into_container(),
|
||||
big_lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
let input_ct_list_gpu = CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&input_ks_list,
|
||||
&streams,
|
||||
);
|
||||
|
||||
let output_ct_list = LweCiphertextList::new(
|
||||
Scalar::ZERO,
|
||||
lwe_sk.lwe_dimension().to_lwe_size(),
|
||||
LweCiphertextCount(elements_per_stream as usize),
|
||||
params.ciphertext_modulus.unwrap(),
|
||||
);
|
||||
let mut output_ct_list_gpu = CudaLweCiphertextList::from_lwe_ciphertext_list(
|
||||
&output_ct_list,
|
||||
&streams,
|
||||
);
|
||||
|
||||
let h_indexes = (0..elements_per_stream)
|
||||
.map(CastFrom::cast_from)
|
||||
.collect::<Vec<_>>();
|
||||
let cuda_indexes_vec = CudaIndexes::new(&h_indexes, &streams, 0);
|
||||
|
||||
bench_id = format!("{bench_name}::{elements_per_stream}::{name}");
|
||||
{
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
cuda_keyswitch_lwe_ciphertext(
|
||||
gpu_keys.ksk.as_ref().unwrap(),
|
||||
&input_ct_list_gpu,
|
||||
&mut output_ct_list_gpu,
|
||||
&cuda_indexes_vec.d_input,
|
||||
&cuda_indexes_vec.d_output,
|
||||
&streams,
|
||||
);
|
||||
|
||||
black_box(&mut output_ct_list_gpu);
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let gpu_keys_vec = cuda_local_keys_core(&cpu_keys, None);
|
||||
|
||||
Reference in New Issue
Block a user