chore(gpu): bench KS latency batches

This commit is contained in:
Andrei Stoian
2025-11-06 18:39:20 +01:00
parent c3017341bd
commit 97f8ee3a84

View File

@@ -386,7 +386,7 @@ mod cuda {
.keyswitch_key(ksk_big_to_small)
.build();
let bench_id;
let mut bench_id;
match get_bench_type() {
BenchmarkType::Latency => {
@@ -429,6 +429,72 @@ mod cuda {
})
});
}
for elements_per_stream in (2..=32u64)
{
let plaintext_list = PlaintextList::new(
Scalar::ZERO,
PlaintextCount(elements_per_stream as usize),
);
let mut input_ct_list = LweCiphertextList::new(
Scalar::ZERO,
big_lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream as usize),
params.ciphertext_modulus.unwrap(),
);
encrypt_lwe_ciphertext_list(
&big_lwe_sk,
&mut input_ct_list,
&plaintext_list,
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);
let input_ks_list = LweCiphertextList::from_container(
input_ct_list.into_container(),
big_lwe_sk.lwe_dimension().to_lwe_size(),
params.ciphertext_modulus.unwrap(),
);
let input_ct_list_gpu = CudaLweCiphertextList::from_lwe_ciphertext_list(
&input_ks_list,
&streams,
);
let output_ct_list = LweCiphertextList::new(
Scalar::ZERO,
lwe_sk.lwe_dimension().to_lwe_size(),
LweCiphertextCount(elements_per_stream as usize),
params.ciphertext_modulus.unwrap(),
);
let mut output_ct_list_gpu = CudaLweCiphertextList::from_lwe_ciphertext_list(
&output_ct_list,
&streams,
);
let h_indexes = (0..elements_per_stream)
.map(CastFrom::cast_from)
.collect::<Vec<_>>();
let cuda_indexes_vec = CudaIndexes::new(&h_indexes, &streams, 0);
bench_id = format!("{bench_name}::{elements_per_stream}::{name}");
{
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
cuda_keyswitch_lwe_ciphertext(
gpu_keys.ksk.as_ref().unwrap(),
&input_ct_list_gpu,
&mut output_ct_list_gpu,
&cuda_indexes_vec.d_input,
&cuda_indexes_vec.d_output,
&streams,
);
black_box(&mut output_ct_list_gpu);
})
});
}
}
}
BenchmarkType::Throughput => {
let gpu_keys_vec = cuda_local_keys_core(&cpu_keys, None);