chore(bench): benchmark shortint ops against multi-bit parameters

This commit is contained in:
David Testé
2023-08-30 10:18:34 +02:00
committed by David Testé
parent bf36316c12
commit 65749cb39b
2 changed files with 162 additions and 91 deletions

View File

@@ -438,6 +438,15 @@ bench_shortint: install_rs_check_toolchain
--bench shortint-bench \
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache,$(AVX512_FEATURE) -p tfhe
.PHONY: bench_shortint_multi_bit # Run benchmarks for shortint using multi-bit parameters
bench_shortint_multi_bit: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_TYPE=MULTI_BIT \
__TFHE_RS_BENCH_OP_FLAVOR=$(BENCH_OP_FLAVOR) \
cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \
--bench shortint-bench \
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache,$(AVX512_FEATURE) -p tfhe --
.PHONY: bench_boolean # Run benchmarks for boolean
bench_boolean: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \

View File

@@ -41,20 +41,59 @@ const SERVER_KEY_BENCH_PARAMS_EXTENDED: [ClassicPBSParameters; 15] = [
PARAM_MESSAGE_8_CARRY_0_KS_PBS,
];
const SERVER_KEY_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 2] = [
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
];
const SERVER_KEY_MULTI_BIT_BENCH_PARAMS_EXTENDED: [MultiBitPBSParameters; 6] = [
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS,
];
enum BenchParamsSet {
Standard,
Extended,
}
fn benchmark_parameters(params_set: BenchParamsSet) -> Vec<PBSParameters> {
let is_multi_bit = match env::var("__TFHE_RS_BENCH_TYPE") {
Ok(val) => val.to_lowercase() == "multi_bit",
Err(_) => false,
};
if is_multi_bit {
let params = match params_set {
BenchParamsSet::Standard => SERVER_KEY_MULTI_BIT_BENCH_PARAMS.to_vec(),
BenchParamsSet::Extended => SERVER_KEY_MULTI_BIT_BENCH_PARAMS_EXTENDED.to_vec(),
};
params.iter().map(|p| (*p).into()).collect()
} else {
let params = match params_set {
BenchParamsSet::Standard => SERVER_KEY_BENCH_PARAMS.to_vec(),
BenchParamsSet::Extended => SERVER_KEY_BENCH_PARAMS_EXTENDED.to_vec(),
};
params.iter().map(|p| (*p).into()).collect()
}
}
fn bench_server_key_unary_function<F>(
c: &mut Criterion,
bench_name: &str,
display_name: &str,
unary_op: F,
params: &[ClassicPBSParameters],
params_set: BenchParamsSet,
) where
F: Fn(&ServerKey, &mut Ciphertext),
{
let mut bench_group = c.benchmark_group(bench_name);
for param in params.iter() {
let param: PBSParameters = (*param).into();
let keys = KEY_CACHE.get_from_param(param);
for param in benchmark_parameters(params_set).iter() {
let keys = KEY_CACHE.get_from_param(*param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let mut rng = rand::thread_rng();
@@ -74,7 +113,7 @@ fn bench_server_key_unary_function<F>(
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
display_name,
&OperatorType::Atomic,
@@ -91,15 +130,14 @@ fn bench_server_key_binary_function<F>(
bench_name: &str,
display_name: &str,
binary_op: F,
params: &[ClassicPBSParameters],
params_set: BenchParamsSet,
) where
F: Fn(&ServerKey, &mut Ciphertext, &mut Ciphertext),
{
let mut bench_group = c.benchmark_group(bench_name);
for param in params.iter() {
let param: PBSParameters = (*param).into();
let keys = KEY_CACHE.get_from_param(param);
for param in benchmark_parameters(params_set).iter() {
let keys = KEY_CACHE.get_from_param(*param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let mut rng = rand::thread_rng();
@@ -121,7 +159,7 @@ fn bench_server_key_binary_function<F>(
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
display_name,
&OperatorType::Atomic,
@@ -138,15 +176,14 @@ fn bench_server_key_binary_scalar_function<F>(
bench_name: &str,
display_name: &str,
binary_op: F,
params: &[ClassicPBSParameters],
params_set: BenchParamsSet,
) where
F: Fn(&ServerKey, &mut Ciphertext, u8),
{
let mut bench_group = c.benchmark_group(bench_name);
for param in params {
let param: PBSParameters = (*param).into();
let keys = KEY_CACHE.get_from_param(param);
for param in benchmark_parameters(params_set).iter() {
let keys = KEY_CACHE.get_from_param(*param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let mut rng = rand::thread_rng();
@@ -167,7 +204,7 @@ fn bench_server_key_binary_scalar_function<F>(
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
display_name,
&OperatorType::Atomic,
@@ -184,15 +221,14 @@ fn bench_server_key_binary_scalar_division_function<F>(
bench_name: &str,
display_name: &str,
binary_op: F,
params: &[ClassicPBSParameters],
params_set: BenchParamsSet,
) where
F: Fn(&ServerKey, &mut Ciphertext, u8),
{
let mut bench_group = c.benchmark_group(bench_name);
for param in params {
let param: PBSParameters = (*param).into();
let keys = KEY_CACHE.get_from_param(param);
for param in benchmark_parameters(params_set).iter() {
let keys = KEY_CACHE.get_from_param(*param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let mut rng = rand::thread_rng();
@@ -217,7 +253,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
display_name,
&OperatorType::Atomic,
@@ -229,12 +265,11 @@ fn bench_server_key_binary_scalar_division_function<F>(
bench_group.finish()
}
fn carry_extract(c: &mut Criterion) {
fn carry_extract_bench(c: &mut Criterion, params_set: BenchParamsSet) {
let mut bench_group = c.benchmark_group("carry_extract");
for param in SERVER_KEY_BENCH_PARAMS {
let param: PBSParameters = param.into();
let keys = KEY_CACHE.get_from_param(param);
for param in benchmark_parameters(params_set).iter() {
let keys = KEY_CACHE.get_from_param(*param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let mut rng = rand::thread_rng();
@@ -254,7 +289,7 @@ fn carry_extract(c: &mut Criterion) {
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
"carry_extract",
&OperatorType::Atomic,
@@ -266,12 +301,11 @@ fn carry_extract(c: &mut Criterion) {
bench_group.finish()
}
fn programmable_bootstrapping(c: &mut Criterion) {
fn programmable_bootstrapping_bench(c: &mut Criterion, params_set: BenchParamsSet) {
let mut bench_group = c.benchmark_group("programmable_bootstrap");
for param in SERVER_KEY_BENCH_PARAMS {
let param: PBSParameters = param.into();
let keys = KEY_CACHE.get_from_param(param);
for param in benchmark_parameters(params_set).iter() {
let keys = KEY_CACHE.get_from_param(*param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let mut rng = rand::thread_rng();
@@ -294,7 +328,7 @@ fn programmable_bootstrapping(c: &mut Criterion) {
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
"pbs",
&OperatorType::Atomic,
@@ -312,9 +346,18 @@ fn server_key_from_compressed_key(c: &mut Criterion) {
.sample_size(10)
.measurement_time(std::time::Duration::from_secs(60));
for param in SERVER_KEY_BENCH_PARAMS {
let param: PBSParameters = param.into();
let keys = KEY_CACHE.get_from_param(param);
let mut params = SERVER_KEY_BENCH_PARAMS_EXTENDED
.iter()
.map(|p| (*p).into())
.collect::<Vec<PBSParameters>>();
let multi_bit_params = SERVER_KEY_MULTI_BIT_BENCH_PARAMS_EXTENDED
.iter()
.map(|p| (*p).into())
.collect::<Vec<PBSParameters>>();
params.extend(&multi_bit_params);
for param in params.iter() {
let keys = KEY_CACHE.get_from_param(*param);
let sks_compressed = CompressedServerKey::new(&keys.client_key());
let bench_id = format!("shortint::uncompress_key::{}", param.name());
@@ -333,7 +376,7 @@ fn server_key_from_compressed_key(c: &mut Criterion) {
write_to_json::<u64, _>(
&bench_id,
param,
*param,
param.name(),
"uncompress_key",
&OperatorType::Atomic,
@@ -374,7 +417,7 @@ fn _bench_wopbs_param_message_8_norm2_5(c: &mut Criterion) {
}
macro_rules! define_server_key_unary_bench_fn (
(method_name:$server_key_method:ident, display_name:$name:ident, $params:expr) => {
(method_name:$server_key_method:ident, display_name:$name:ident, $params_set:expr) => {
fn $server_key_method(c: &mut Criterion) {
bench_server_key_unary_function(
c,
@@ -382,13 +425,13 @@ macro_rules! define_server_key_unary_bench_fn (
stringify!($name),
|server_key, lhs| {
let _ = server_key.$server_key_method(lhs);},
$params)
$params_set)
}
}
);
macro_rules! define_server_key_bench_fn (
(method_name:$server_key_method:ident, display_name:$name:ident, $params:expr) => {
(method_name:$server_key_method:ident, display_name:$name:ident, $params_set:expr) => {
fn $server_key_method(c: &mut Criterion) {
bench_server_key_binary_function(
c,
@@ -396,13 +439,13 @@ macro_rules! define_server_key_bench_fn (
stringify!($name),
|server_key, lhs, rhs| {
let _ = server_key.$server_key_method(lhs, rhs);},
$params)
$params_set)
}
}
);
macro_rules! define_server_key_scalar_bench_fn (
(method_name:$server_key_method:ident, display_name:$name:ident, $params:expr) => {
(method_name:$server_key_method:ident, display_name:$name:ident, $params_set:expr) => {
fn $server_key_method(c: &mut Criterion) {
bench_server_key_binary_scalar_function(
c,
@@ -410,13 +453,13 @@ macro_rules! define_server_key_scalar_bench_fn (
stringify!($name),
|server_key, lhs, rhs| {
let _ = server_key.$server_key_method(lhs, rhs);},
$params)
$params_set)
}
}
);
macro_rules! define_server_key_scalar_div_bench_fn (
(method_name:$server_key_method:ident, display_name:$name:ident, $params:expr) => {
(method_name:$server_key_method:ident, display_name:$name:ident, $params_set:expr) => {
fn $server_key_method(c: &mut Criterion) {
bench_server_key_binary_scalar_division_function(
c,
@@ -424,7 +467,19 @@ macro_rules! define_server_key_scalar_div_bench_fn (
stringify!($name),
|server_key, lhs, rhs| {
let _ = server_key.$server_key_method(lhs, rhs);},
$params)
$params_set)
}
}
);
macro_rules! define_custom_bench_fn (
(function_name:$function:ident, $params_set:expr) => {
fn $function(c: &mut Criterion) {
::paste::paste! {
[<$function _bench>](
c,
$params_set)
}
}
}
);
@@ -432,251 +487,258 @@ macro_rules! define_server_key_scalar_div_bench_fn (
define_server_key_unary_bench_fn!(
method_name: unchecked_neg,
display_name: negation,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: unchecked_add,
display_name: add,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_bench_fn!(
method_name: unchecked_sub,
display_name: sub,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_bench_fn!(
method_name: unchecked_mul_lsb,
display_name: mul,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_bench_fn!(
method_name: unchecked_mul_msb,
display_name: mul,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: unchecked_div,
display_name: div,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_bench_fn!(
method_name: smart_bitand,
display_name: bitand,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: smart_bitor,
display_name: bitor,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: smart_bitxor,
display_name: bitxor,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: smart_add,
display_name: add,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: smart_sub,
display_name: sub,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: smart_mul_lsb,
display_name: mul,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: bitand,
display_name: bitand,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: bitor,
display_name: bitor,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: bitxor,
display_name: bitxor,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: add,
display_name: add,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: sub,
display_name: sub,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: mul,
display_name: mul,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: div,
display_name: div,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: greater,
display_name: greater,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: greater_or_equal,
display_name: greater_or_equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: less,
display_name: less,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: less_or_equal,
display_name: less_or_equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: equal,
display_name: equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: not_equal,
display_name: not_equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_unary_bench_fn!(
method_name: neg,
display_name: negation,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: unchecked_greater,
display_name: greater_than,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: unchecked_less,
display_name: less_than,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_bench_fn!(
method_name: unchecked_equal,
display_name: equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: unchecked_scalar_add,
display_name: add,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_scalar_bench_fn!(
method_name: unchecked_scalar_sub,
display_name: sub,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_scalar_bench_fn!(
method_name: unchecked_scalar_mul,
display_name: mul,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_scalar_bench_fn!(
method_name: unchecked_scalar_left_shift,
display_name: left_shift,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: unchecked_scalar_right_shift,
display_name: right_shift,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_div_bench_fn!(
method_name: unchecked_scalar_div,
display_name: div,
&SERVER_KEY_BENCH_PARAMS_EXTENDED
BenchParamsSet::Extended
);
define_server_key_scalar_div_bench_fn!(
method_name: unchecked_scalar_mod,
display_name: modulo,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_add,
display_name: add,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_sub,
display_name: sub,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_mul,
display_name: mul,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_left_shift,
display_name: left_shift,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_right_shift,
display_name: right_shift,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_div_bench_fn!(
method_name: scalar_div,
display_name: div,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_div_bench_fn!(
method_name: scalar_mod,
display_name: modulo,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_greater,
display_name: greater,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_greater_or_equal,
display_name: greater_or_equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_less,
display_name: less,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_bench_fn!(
method_name: scalar_less_or_equal,
display_name: less_or_equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_div_bench_fn!(
method_name: scalar_equal,
display_name: equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_server_key_scalar_div_bench_fn!(
method_name: scalar_not_equal,
display_name: not_equal,
&SERVER_KEY_BENCH_PARAMS
BenchParamsSet::Standard
);
define_custom_bench_fn!(function_name: carry_extract, BenchParamsSet::Standard);
define_custom_bench_fn!(
function_name: programmable_bootstrapping,
BenchParamsSet::Standard
);
criterion_group!(