chore(bench): new parameters set to run core_crypto bench for docs

This creates extended parameters set to reflect what's displayed
in the documentation.
This commit is contained in:
David Testé
2025-10-27 16:46:49 +01:00
committed by David Testé
parent 70773e442c
commit b0b49ae533
8 changed files with 316 additions and 119 deletions

View File

@@ -11,7 +11,10 @@ on:
options: options:
- classical - classical
- multi_bit - multi_bit
- both - classical + multi_bit
- classical_documentation
- multi_bit_documentation
- classical_documentation + multi_bit_documentation
schedule: schedule:
# Weekly benchmarks will be triggered each Saturday at 5a.m. # Weekly benchmarks will be triggered each Saturday at 5a.m.
@@ -43,8 +46,10 @@ jobs:
- name: Set parameters types - name: Set parameters types
if: github.event_name == 'workflow_dispatch' if: github.event_name == 'workflow_dispatch'
run: | run: |
if [[ "${INPUTS_PARAM_TYPE}" == "both" ]]; then if [[ "${INPUTS_PARAM_TYPE}" == "classical + multi_bit" ]]; then
echo "PARAM_TYPE=[\"classical\", \"multi_bit\"]" >> "${GITHUB_ENV}" echo "PARAM_TYPE=[\"classical\", \"multi_bit\"]" >> "${GITHUB_ENV}"
elif [[ "${INPUTS_PARAM_TYPE}" == "classical_documentation + multi_bit_documentation" ]]; then
echo "PARAM_TYPE=[\"classical_documentation\", \"multi_bit_documentation\"]" >> "${GITHUB_ENV}"
else else
echo "PARAM_TYPE=[\"${INPUTS_PARAM_TYPE}\"]" >> "${GITHUB_ENV}" echo "PARAM_TYPE=[\"${INPUTS_PARAM_TYPE}\"]" >> "${GITHUB_ENV}"
fi fi

View File

@@ -59,7 +59,10 @@ on:
options: options:
- classical - classical
- multi_bit - multi_bit
- both - classical + multi_bit
- classical_documentation
- multi_bit_documentation
- classical_documentation + multi_bit_documentation
permissions: {} permissions: {}

View File

@@ -4,14 +4,13 @@ use benchmark::params::{
benchmark_compression_parameters, benchmark_parameters, multi_bit_benchmark_parameters, benchmark_compression_parameters, benchmark_parameters, multi_bit_benchmark_parameters,
}; };
use benchmark::utilities::{ use benchmark::utilities::{
get_bench_type, throughput_num_threads, write_to_json, BenchmarkType, CryptoParametersRecord, get_bench_type, get_param_type, throughput_num_threads, write_to_json, BenchmarkType,
OperatorType, CryptoParametersRecord, OperatorType, ParamType,
}; };
use criterion::{black_box, Criterion, Throughput}; use criterion::{black_box, Criterion, Throughput};
use itertools::Itertools; use itertools::Itertools;
use rayon::prelude::*; use rayon::prelude::*;
use serde::Serialize; use serde::Serialize;
use std::env;
use tfhe::core_crypto::prelude::*; use tfhe::core_crypto::prelude::*;
// TODO Refactor KS, PBS and KS-PBS benchmarks into a single generic function. // TODO Refactor KS, PBS and KS-PBS benchmarks into a single generic function.
@@ -788,6 +787,13 @@ mod cuda {
cuda_packing_keyswitch(&mut criterion, &benchmark_parameters()); cuda_packing_keyswitch(&mut criterion, &benchmark_parameters());
} }
pub fn cuda_ks_group_documentation() {
let mut criterion: Criterion<_> = (Criterion::default().sample_size(15))
.measurement_time(std::time::Duration::from_secs(60))
.configure_from_args();
cuda_keyswitch(&mut criterion, &benchmark_parameters());
}
pub fn cuda_multi_bit_ks_group() { pub fn cuda_multi_bit_ks_group() {
let mut criterion: Criterion<_> = let mut criterion: Criterion<_> =
(Criterion::default().sample_size(2000)).configure_from_args(); (Criterion::default().sample_size(2000)).configure_from_args();
@@ -798,10 +804,23 @@ mod cuda {
cuda_keyswitch(&mut criterion, &multi_bit_parameters); cuda_keyswitch(&mut criterion, &multi_bit_parameters);
cuda_packing_keyswitch(&mut criterion, &multi_bit_parameters); cuda_packing_keyswitch(&mut criterion, &multi_bit_parameters);
} }
pub fn cuda_multi_bit_ks_group_documentation() {
let mut criterion: Criterion<_> =
(Criterion::default().sample_size(2000)).configure_from_args();
let multi_bit_parameters = multi_bit_benchmark_parameters()
.into_iter()
.map(|(string, params, _)| (string, params))
.collect_vec();
cuda_keyswitch(&mut criterion, &multi_bit_parameters);
}
} }
#[cfg(feature = "gpu")] #[cfg(feature = "gpu")]
use cuda::{cuda_ks_group, cuda_multi_bit_ks_group}; use cuda::{
cuda_ks_group, cuda_ks_group_documentation, cuda_multi_bit_ks_group,
cuda_multi_bit_ks_group_documentation,
};
pub fn ks_group() { pub fn ks_group() {
let mut criterion: Criterion<_> = (Criterion::default() let mut criterion: Criterion<_> = (Criterion::default()
@@ -846,39 +865,32 @@ pub fn packing_ks_group() {
} }
#[cfg(feature = "gpu")] #[cfg(feature = "gpu")]
fn go_through_gpu_bench_groups(val: &str) { fn go_through_gpu_bench_groups() {
match val.to_lowercase().as_str() { match get_param_type() {
"classical" => cuda_ks_group(), ParamType::Classical => cuda_ks_group(),
"multi_bit" => cuda_multi_bit_ks_group(), ParamType::ClassicalDocumentation => cuda_ks_group_documentation(),
_ => panic!("unknown benchmark operations flavor"), ParamType::MultiBit => cuda_multi_bit_ks_group(),
ParamType::MultiBitDocumentation => cuda_multi_bit_ks_group_documentation(),
}; };
} }
#[cfg(not(feature = "gpu"))] #[cfg(not(feature = "gpu"))]
fn go_through_cpu_bench_groups(val: &str) { fn go_through_cpu_bench_groups() {
match val.to_lowercase().as_str() { match get_param_type() {
"classical" => { ParamType::Classical => {
ks_group(); ks_group();
packing_ks_group() packing_ks_group()
} }
"multi_bit" => multi_bit_ks_group(), ParamType::ClassicalDocumentation => ks_group(),
_ => panic!("unknown benchmark operations flavor"), ParamType::MultiBit | ParamType::MultiBitDocumentation => multi_bit_ks_group(),
} }
} }
fn main() { fn main() {
match env::var("__TFHE_RS_PARAM_TYPE") { #[cfg(feature = "gpu")]
Ok(val) => { go_through_gpu_bench_groups();
#[cfg(feature = "gpu")] #[cfg(not(feature = "gpu"))]
go_through_gpu_bench_groups(&val); go_through_cpu_bench_groups();
#[cfg(not(feature = "gpu"))]
go_through_cpu_bench_groups(&val);
}
Err(_) => {
ks_group();
packing_ks_group()
}
};
Criterion::default().configure_from_args().final_summary(); Criterion::default().configure_from_args().final_summary();
} }

View File

@@ -2,13 +2,12 @@ use benchmark::params::{
benchmark_parameters, multi_bit_benchmark_parameters_with_grouping, multi_bit_num_threads, benchmark_parameters, multi_bit_benchmark_parameters_with_grouping, multi_bit_num_threads,
}; };
use benchmark::utilities::{ use benchmark::utilities::{
get_bench_type, throughput_num_threads, write_to_json, BenchmarkType, CryptoParametersRecord, get_bench_type, get_param_type, throughput_num_threads, write_to_json, BenchmarkType,
OperatorType, CryptoParametersRecord, OperatorType, ParamType,
}; };
use criterion::{black_box, Criterion, Throughput}; use criterion::{black_box, Criterion, Throughput};
use rayon::prelude::*; use rayon::prelude::*;
use serde::Serialize; use serde::Serialize;
use std::env;
use tfhe::core_crypto::prelude::*; use tfhe::core_crypto::prelude::*;
// TODO Refactor KS, PBS and KS-PBS benchmarks into a single generic function. // TODO Refactor KS, PBS and KS-PBS benchmarks into a single generic function.
@@ -1147,11 +1146,6 @@ pub fn ks_pbs_group() {
pub fn multi_bit_ks_pbs_group() { pub fn multi_bit_ks_pbs_group() {
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args(); let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
multi_bit_ks_pbs(
&mut criterion,
&multi_bit_benchmark_parameters_with_grouping(),
false,
);
multi_bit_ks_pbs( multi_bit_ks_pbs(
&mut criterion, &mut criterion,
&multi_bit_benchmark_parameters_with_grouping(), &multi_bit_benchmark_parameters_with_grouping(),
@@ -1160,36 +1154,26 @@ pub fn multi_bit_ks_pbs_group() {
} }
#[cfg(feature = "gpu")] #[cfg(feature = "gpu")]
fn go_through_gpu_bench_groups(val: &str) { fn go_through_gpu_bench_groups() {
match val.to_lowercase().as_str() { match get_param_type() {
"classical" => cuda_ks_pbs_group(), ParamType::Classical | ParamType::ClassicalDocumentation => cuda_ks_pbs_group(),
"multi_bit" => cuda_multi_bit_ks_pbs_group(), ParamType::MultiBit | ParamType::MultiBitDocumentation => cuda_multi_bit_ks_pbs_group(),
_ => panic!("unknown benchmark operations flavor"),
}; };
} }
#[cfg(not(feature = "gpu"))] #[cfg(not(feature = "gpu"))]
fn go_through_cpu_bench_groups(val: &str) { fn go_through_cpu_bench_groups() {
match val.to_lowercase().as_str() { match get_param_type() {
"classical" => ks_pbs_group(), ParamType::Classical | ParamType::ClassicalDocumentation => ks_pbs_group(),
"multi_bit" => multi_bit_ks_pbs_group(), ParamType::MultiBit | ParamType::MultiBitDocumentation => multi_bit_ks_pbs_group(),
_ => panic!("unknown benchmark operations flavor"),
} }
} }
fn main() { fn main() {
match env::var("__TFHE_RS_PARAM_TYPE") { #[cfg(feature = "gpu")]
Ok(val) => { go_through_gpu_bench_groups();
#[cfg(feature = "gpu")] #[cfg(not(feature = "gpu"))]
go_through_gpu_bench_groups(&val); go_through_cpu_bench_groups();
#[cfg(not(feature = "gpu"))]
go_through_cpu_bench_groups(&val);
}
Err(_) => {
ks_pbs_group();
multi_bit_ks_pbs_group()
}
};
Criterion::default().configure_from_args().final_summary(); Criterion::default().configure_from_args().final_summary();
} }

View File

@@ -3,13 +3,12 @@ use benchmark::params::{
multi_bit_benchmark_parameters_with_grouping, multi_bit_num_threads, multi_bit_benchmark_parameters_with_grouping, multi_bit_num_threads,
}; };
use benchmark::utilities::{ use benchmark::utilities::{
get_bench_type, throughput_num_threads, write_to_json, BenchmarkType, CryptoParametersRecord, get_bench_type, get_param_type, throughput_num_threads, write_to_json, BenchmarkType,
OperatorType, CryptoParametersRecord, OperatorType, ParamType,
}; };
use criterion::{black_box, Criterion, Throughput}; use criterion::{black_box, Criterion, Throughput};
use rayon::prelude::*; use rayon::prelude::*;
use serde::Serialize; use serde::Serialize;
use std::env;
use tfhe::core_crypto::commons::math::ntt::ntt64::Ntt64; use tfhe::core_crypto::commons::math::ntt::ntt64::Ntt64;
use tfhe::core_crypto::prelude::*; use tfhe::core_crypto::prelude::*;
@@ -1461,6 +1460,11 @@ pub fn pbs_group() {
mem_optimized_batched_pbs(&mut criterion, &benchmark_parameters()); mem_optimized_batched_pbs(&mut criterion, &benchmark_parameters());
} }
pub fn pbs_group_documentation() {
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
mem_optimized_pbs(&mut criterion, &benchmark_parameters());
}
pub fn multi_bit_pbs_group() { pub fn multi_bit_pbs_group() {
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args(); let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
multi_bit_pbs( multi_bit_pbs(
@@ -1475,37 +1479,40 @@ pub fn multi_bit_pbs_group() {
); );
} }
pub fn multi_bit_pbs_group_documentation() {
let mut criterion: Criterion<_> = (Criterion::default()).configure_from_args();
multi_bit_pbs(
&mut criterion,
&multi_bit_benchmark_parameters_with_grouping(),
true,
);
}
#[cfg(feature = "gpu")] #[cfg(feature = "gpu")]
fn go_through_gpu_bench_groups(val: &str) { fn go_through_gpu_bench_groups() {
match val.to_lowercase().as_str() { match get_param_type() {
"classical" => cuda_pbs_group(), ParamType::Classical => cuda_pbs_group(),
"multi_bit" => cuda_multi_bit_pbs_group(), ParamType::ClassicalDocumentation => cuda_pbs_group(),
_ => panic!("unknown benchmark operations flavor"), ParamType::MultiBit => cuda_multi_bit_pbs_group(),
ParamType::MultiBitDocumentation => cuda_multi_bit_pbs_group(),
}; };
} }
#[cfg(not(feature = "gpu"))] #[cfg(not(feature = "gpu"))]
fn go_through_cpu_bench_groups(val: &str) { fn go_through_cpu_bench_groups() {
match val.to_lowercase().as_str() { match get_param_type() {
"classical" => pbs_group(), ParamType::Classical => pbs_group(),
"multi_bit" => multi_bit_pbs_group(), ParamType::ClassicalDocumentation => pbs_group_documentation(),
_ => panic!("unknown benchmark operations flavor"), ParamType::MultiBit => multi_bit_pbs_group(),
ParamType::MultiBitDocumentation => multi_bit_pbs_group_documentation(),
} }
} }
fn main() { fn main() {
match env::var("__TFHE_RS_PARAM_TYPE") { #[cfg(feature = "gpu")]
Ok(val) => { go_through_gpu_bench_groups();
#[cfg(feature = "gpu")] #[cfg(not(feature = "gpu"))]
go_through_gpu_bench_groups(&val); go_through_cpu_bench_groups();
#[cfg(not(feature = "gpu"))]
go_through_cpu_bench_groups(&val);
}
Err(_) => {
pbs_group();
multi_bit_pbs_group()
}
};
Criterion::default().configure_from_args().final_summary(); Criterion::default().configure_from_args().final_summary();
} }

View File

@@ -49,6 +49,17 @@ pub mod shortint_params {
BENCH_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128, BENCH_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128,
]; ];
pub const SHORTINT_BENCH_PARAMS_TUNIFORM_DOCUMENTATION: [ClassicPBSParameters; 8] = [
BENCH_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64,
BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
BENCH_PARAM_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64,
BENCH_PARAM_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64,
BENCH_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128,
BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
BENCH_PARAM_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M128,
BENCH_PARAM_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128,
];
#[cfg(feature = "gpu")] #[cfg(feature = "gpu")]
pub const SHORTINT_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 6] = [ pub const SHORTINT_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 6] = [
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128, BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128,
@@ -59,6 +70,34 @@ pub mod shortint_params {
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
]; ];
#[cfg(feature = "gpu")]
pub const SHORTINT_MULTI_BIT_BENCH_PARAMS_DOCUMENTATION: [(&str, MultiBitPBSParameters); 6] = [
(
"BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128,
),
(
"BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
),
(
"BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M128,
),
];
#[cfg(not(feature = "gpu"))] #[cfg(not(feature = "gpu"))]
pub const SHORTINT_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 6] = [ pub const SHORTINT_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 6] = [
BENCH_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, BENCH_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
@@ -69,26 +108,78 @@ pub mod shortint_params {
BENCH_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, BENCH_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
]; ];
#[cfg(not(feature = "gpu"))]
pub const SHORTINT_MULTI_BIT_BENCH_PARAMS_DOCUMENTATION: [(&str, MultiBitPBSParameters); 8] = [
// Message_1_carry_1 2M64 and 2M128 are exactly the same, so we run one variant only
// otherwise we would get a panic due to unicity rules of benchmark IDs/
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M128,
),
(
"BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128",
BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128,
),
];
#[cfg(feature = "internal-keycache")] #[cfg(feature = "internal-keycache")]
pub mod shortint_params_keycache { pub mod shortint_params_keycache {
use super::*; use super::*;
use crate::utilities::CryptoParametersRecord; use crate::utilities::{get_param_type, CryptoParametersRecord, ParamType};
use tfhe::keycache::NamedParam; use tfhe::keycache::NamedParam;
pub fn benchmark_parameters() -> Vec<(String, CryptoParametersRecord<u64>)> { pub fn benchmark_parameters() -> Vec<(String, CryptoParametersRecord<u64>)> {
match get_parameters_set() { match get_parameters_set() {
ParametersSet::Default => SHORTINT_BENCH_PARAMS_TUNIFORM ParametersSet::Default => {
.iter() let iterator = match get_param_type() {
.chain(SHORTINT_BENCH_PARAMS_GAUSSIAN.iter()) ParamType::ClassicalDocumentation => {
.map(|params| { SHORTINT_BENCH_PARAMS_TUNIFORM_DOCUMENTATION
( .iter()
params.name(), .chain([].iter()) // Use an empty iterator to return the same type
<ClassicPBSParameters as Into<AtomicPatternParameters>>::into(*params) // as the fallback case
}
_ => SHORTINT_BENCH_PARAMS_TUNIFORM
.iter()
.chain(SHORTINT_BENCH_PARAMS_GAUSSIAN.iter()),
};
iterator
.map(|params| {
(
params.name(),
<ClassicPBSParameters as Into<AtomicPatternParameters>>::into(
*params,
)
.to_owned() .to_owned()
.into(), .into(),
) )
}) })
.collect(), .collect()
}
ParametersSet::All => { ParametersSet::All => {
filter_parameters( filter_parameters(
&BENCH_ALL_CLASSIC_PBS_PARAMETERS, &BENCH_ALL_CLASSIC_PBS_PARAMETERS,
@@ -124,18 +215,38 @@ pub mod shortint_params {
pub fn multi_bit_benchmark_parameters( pub fn multi_bit_benchmark_parameters(
) -> Vec<(String, CryptoParametersRecord<u64>, LweBskGroupingFactor)> { ) -> Vec<(String, CryptoParametersRecord<u64>, LweBskGroupingFactor)> {
match get_parameters_set() { match get_parameters_set() {
ParametersSet::Default => SHORTINT_MULTI_BIT_BENCH_PARAMS ParametersSet::Default => match get_param_type() {
.iter() ParamType::MultiBitDocumentation => {
.map(|params| { SHORTINT_MULTI_BIT_BENCH_PARAMS_DOCUMENTATION
( .iter()
params.name(), .map(|(name, params)| {
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(*params) (
name.to_string(),
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(
*params,
)
.to_owned()
.into(),
params.grouping_factor,
)
})
.collect()
}
_ => SHORTINT_MULTI_BIT_BENCH_PARAMS
.iter()
.map(|params| {
(
params.name(),
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(
*params,
)
.to_owned() .to_owned()
.into(), .into(),
params.grouping_factor, params.grouping_factor,
) )
}) })
.collect(), .collect(),
},
ParametersSet::All => { ParametersSet::All => {
let desired_backend = if cfg!(feature = "gpu") { let desired_backend = if cfg!(feature = "gpu") {
DesiredBackend::Gpu DesiredBackend::Gpu
@@ -165,18 +276,38 @@ pub mod shortint_params {
pub fn multi_bit_benchmark_parameters_with_grouping( pub fn multi_bit_benchmark_parameters_with_grouping(
) -> Vec<(String, CryptoParametersRecord<u64>, LweBskGroupingFactor)> { ) -> Vec<(String, CryptoParametersRecord<u64>, LweBskGroupingFactor)> {
match get_parameters_set() { match get_parameters_set() {
ParametersSet::Default => SHORTINT_MULTI_BIT_BENCH_PARAMS ParametersSet::Default => match get_param_type() {
.iter() ParamType::MultiBitDocumentation => {
.map(|params| { SHORTINT_MULTI_BIT_BENCH_PARAMS_DOCUMENTATION
( .iter()
params.name(), .map(|(name, params)| {
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(*params) (
name.to_string(),
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(
*params,
)
.to_owned()
.into(),
params.grouping_factor,
)
})
.collect()
}
_ => SHORTINT_MULTI_BIT_BENCH_PARAMS
.iter()
.map(|params| {
(
params.name(),
<MultiBitPBSParameters as Into<AtomicPatternParameters>>::into(
*params,
)
.to_owned() .to_owned()
.into(), .into(),
params.grouping_factor, params.grouping_factor,
) )
}) })
.collect(), .collect(),
},
ParametersSet::All => { ParametersSet::All => {
let desired_backend = if cfg!(feature = "gpu") { let desired_backend = if cfg!(feature = "gpu") {
DesiredBackend::Gpu DesiredBackend::Gpu

View File

@@ -19,6 +19,14 @@ pub mod shortint_params_aliases {
V1_5_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128; V1_5_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M128;
// KS PBS TUniform // KS PBS TUniform
pub const BENCH_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64: ClassicPBSParameters =
V1_5_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64: ClassicPBSParameters =
V1_5_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64: ClassicPBSParameters =
V1_5_PARAM_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64: ClassicPBSParameters =
V1_5_PARAM_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128: ClassicPBSParameters = pub const BENCH_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128: ClassicPBSParameters =
V1_5_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128; V1_5_PARAM_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128;
pub const BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: ClassicPBSParameters = pub const BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: ClassicPBSParameters =
@@ -82,6 +90,14 @@ pub mod shortint_params_aliases {
MultiBitPBSParameters = MultiBitPBSParameters =
V1_5_PARAM_MULTI_BIT_GROUP_3_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128; V1_5_PARAM_MULTI_BIT_GROUP_3_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128;
// --- Grouping factor 4 // --- Grouping factor 4
pub const BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters = V1_5_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters = V1_5_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters = V1_5_PARAM_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters = V1_5_PARAM_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128: pub const BENCH_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128:
MultiBitPBSParameters = MultiBitPBSParameters =
V1_5_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128; V1_5_PARAM_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128;
@@ -135,6 +151,18 @@ pub mod shortint_params_aliases {
MultiBitPBSParameters = MultiBitPBSParameters =
V1_5_PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128; V1_5_PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128;
// --- Grouping factor 4 // --- Grouping factor 4
pub const BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters =
V1_5_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters =
V1_5_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters =
V1_5_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters =
V1_5_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M64;
pub const BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128: pub const BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128:
MultiBitPBSParameters = MultiBitPBSParameters =
V1_5_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128; V1_5_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_1_CARRY_1_KS_PBS_TUNIFORM_2M128;

View File

@@ -361,10 +361,10 @@ pub struct EnvConfig {
impl EnvConfig { impl EnvConfig {
pub fn new() -> Self { pub fn new() -> Self {
let is_multi_bit = match env::var("__TFHE_RS_PARAM_TYPE") { let is_multi_bit = matches!(
Ok(val) => val.to_lowercase() == "multi_bit", get_param_type(),
Err(_) => false, ParamType::MultiBit | ParamType::MultiBitDocumentation
}; );
let is_fast_bench = match env::var("__TFHE_RS_FAST_BENCH") { let is_fast_bench = match env::var("__TFHE_RS_FAST_BENCH") {
Ok(val) => val.to_lowercase() == "true", Ok(val) => val.to_lowercase() == "true",
@@ -417,6 +417,33 @@ pub fn get_bench_type() -> &'static BenchmarkType {
BENCH_TYPE.get_or_init(|| BenchmarkType::from_env().unwrap()) BENCH_TYPE.get_or_init(|| BenchmarkType::from_env().unwrap())
} }
pub static PARAM_TYPE: OnceLock<ParamType> = OnceLock::new();
pub enum ParamType {
Classical,
MultiBit,
// Variants dedicated to documentation illustration.
ClassicalDocumentation,
MultiBitDocumentation,
}
impl ParamType {
pub fn from_env() -> Result<Self, String> {
let raw_value = env::var("__TFHE_RS_PARAM_TYPE").unwrap_or("classical".to_string());
match raw_value.to_lowercase().as_str() {
"classical" => Ok(ParamType::Classical),
"multi_bit" => Ok(ParamType::MultiBit),
"classical_documentation" => Ok(ParamType::ClassicalDocumentation),
"multi_bit_documentation" => Ok(ParamType::MultiBitDocumentation),
_ => Err(format!("parameters type '{raw_value}' is not supported")),
}
}
}
pub fn get_param_type() -> &'static ParamType {
PARAM_TYPE.get_or_init(|| ParamType::from_env().unwrap())
}
/// Generate a number of threads to use to saturate current machine for throughput measurements. /// Generate a number of threads to use to saturate current machine for throughput measurements.
pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 { pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
let ref_block_count = 32; // Represent a ciphertext of 64 bits for 2_2 parameters set let ref_block_count = 32; // Represent a ciphertext of 64 bits for 2_2 parameters set