fix(gpu): fix pbs and ks benchmarks

This commit is contained in:
Pedro Alves
2025-05-20 15:31:08 +02:00
committed by Pedro Alves
parent 2571196b41
commit 259d125434
2 changed files with 18 additions and 4 deletions

View File

@@ -8,6 +8,7 @@ use benchmark::utilities::{
OperatorType,
};
use criterion::{black_box, Criterion, Throughput};
use itertools::Itertools;
use rayon::prelude::*;
use serde::Serialize;
use std::env;
@@ -325,6 +326,7 @@ mod cuda {
CudaLocalKeys, OperatorType,
};
use criterion::{black_box, Criterion, Throughput};
use itertools::Itertools;
use rayon::prelude::*;
use serde::Serialize;
use tfhe::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
@@ -750,8 +752,12 @@ mod cuda {
pub fn cuda_multi_bit_ks_group() {
let mut criterion: Criterion<_> =
(Criterion::default().sample_size(2000)).configure_from_args();
cuda_keyswitch(&mut criterion, &multi_bit_benchmark_parameters());
cuda_packing_keyswitch(&mut criterion, &multi_bit_benchmark_parameters());
let multi_bit_parameters = multi_bit_benchmark_parameters()
.into_iter()
.map(|(string, params, _)| (string, params))
.collect_vec();
cuda_keyswitch(&mut criterion, &multi_bit_parameters);
cuda_packing_keyswitch(&mut criterion, &multi_bit_parameters);
}
}
@@ -769,11 +775,16 @@ pub fn ks_group() {
}
pub fn multi_bit_ks_group() {
let multi_bit_parameters = multi_bit_benchmark_parameters()
.into_iter()
.map(|(string, params, _)| (string, params))
.collect_vec();
let mut criterion: Criterion<_> = (Criterion::default()
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60)))
.configure_from_args();
keyswitch(&mut criterion, &multi_bit_benchmark_parameters());
keyswitch(&mut criterion, &multi_bit_parameters);
}
pub fn packing_ks_group() {

View File

@@ -121,7 +121,8 @@ pub mod shortint_params {
)]
}
pub fn multi_bit_benchmark_parameters() -> Vec<(String, CryptoParametersRecord<u64>)> {
pub fn multi_bit_benchmark_parameters(
) -> Vec<(String, CryptoParametersRecord<u64>, LweBskGroupingFactor)> {
match get_parameters_set() {
ParametersSet::Default => SHORTINT_MULTI_BIT_BENCH_PARAMS
.iter()
@@ -131,6 +132,7 @@ pub mod shortint_params {
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(*params)
.to_owned()
.into(),
params.grouping_factor,
)
})
.collect(),
@@ -152,6 +154,7 @@ pub mod shortint_params {
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(*params)
.to_owned()
.into(),
params.grouping_factor,
)
})
.collect()