mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(bench): code refactor and automation for hlapi
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -10,6 +10,7 @@ target/
|
||||
**/*.rmeta
|
||||
**/Cargo.lock
|
||||
**/*.bin
|
||||
**/.DS_Store
|
||||
|
||||
# Some of our bench outputs
|
||||
/tfhe/benchmarks_parameters
|
||||
|
||||
1
tfhe-benchmark/.gitignore
vendored
Normal file
1
tfhe-benchmark/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
benchmarks_parameters/*
|
||||
@@ -11,34 +11,42 @@ use tfhe::keycache::NamedParam;
|
||||
use tfhe::named::Named;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{
|
||||
ClientKey, CompressedServerKey, FheIntegerType, FheUint10, FheUint12, FheUint128, FheUint14,
|
||||
FheUint16, FheUint2, FheUint32, FheUint4, FheUint6, FheUint64, FheUint8, FheUintId, IntegerId,
|
||||
KVStore,
|
||||
ClientKey, CompressedServerKey, FheIntegerType, FheUint, FheUint10, FheUint12, FheUint128,
|
||||
FheUint14, FheUint16, FheUint2, FheUint32, FheUint4, FheUint6, FheUint64, FheUint8, FheUintId,
|
||||
IntegerId, KVStore,
|
||||
};
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
fn bench_fhe_type<FheType>(
|
||||
trait BenchWait {
|
||||
fn wait_bench(&self);
|
||||
}
|
||||
|
||||
impl<Id: FheUintId> BenchWait for FheUint<Id> {
|
||||
fn wait_bench(&self) {
|
||||
self.wait()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: FheWait, T2> BenchWait for (T1, T2) {
|
||||
fn wait_bench(&self) {
|
||||
self.0.wait()
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_fhe_type_op<FheType, F, R>(
|
||||
c: &mut Criterion,
|
||||
client_key: &ClientKey,
|
||||
type_name: &str,
|
||||
bit_size: usize,
|
||||
display_name: &str,
|
||||
func_name: &str,
|
||||
func: F,
|
||||
) where
|
||||
F: Fn(&FheType, &FheType) -> R,
|
||||
R: BenchWait,
|
||||
FheType: FheEncrypt<u128, ClientKey>,
|
||||
FheType: FheWait,
|
||||
for<'a> &'a FheType: Add<&'a FheType, Output = FheType>
|
||||
+ Sub<&'a FheType, Output = FheType>
|
||||
+ Mul<&'a FheType, Output = FheType>
|
||||
+ BitAnd<&'a FheType, Output = FheType>
|
||||
+ BitOr<&'a FheType, Output = FheType>
|
||||
+ BitXor<&'a FheType, Output = FheType>
|
||||
+ Shl<&'a FheType, Output = FheType>
|
||||
+ Shr<&'a FheType, Output = FheType>
|
||||
+ RotateLeft<&'a FheType, Output = FheType>
|
||||
+ RotateRight<&'a FheType, Output = FheType>
|
||||
+ OverflowingAdd<&'a FheType, Output = FheType>
|
||||
+ OverflowingSub<&'a FheType, Output = FheType>,
|
||||
for<'a> FheType: FheMin<&'a FheType, Output = FheType> + FheMax<&'a FheType, Output = FheType>,
|
||||
{
|
||||
let mut bench_group = c.benchmark_group(type_name);
|
||||
let mut bench_prefix = "hlapi".to_string();
|
||||
@@ -71,170 +79,90 @@ fn bench_fhe_type<FheType>(
|
||||
let lhs = FheType::encrypt(rng.gen(), client_key);
|
||||
let rhs = FheType::encrypt(rng.gen(), client_key);
|
||||
|
||||
let mut bench_id;
|
||||
let bench_id = format!("{bench_prefix}::{func_name}::{param_name}::{type_name}");
|
||||
|
||||
bench_id = format!("{bench_prefix}::add::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs + &rhs;
|
||||
res.wait();
|
||||
let res = func(&lhs, &rhs);
|
||||
res.wait_bench();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "add");
|
||||
|
||||
bench_id = format!("{bench_prefix}::overflowing_add::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let (res, flag) = lhs.overflowing_add(&rhs);
|
||||
res.wait();
|
||||
black_box((res, flag))
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "overflowing_add");
|
||||
|
||||
bench_id = format!("{bench_prefix}::overflowing_sub::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let (res, flag) = lhs.overflowing_sub(&rhs);
|
||||
res.wait();
|
||||
black_box((res, flag))
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "overflowing_sub");
|
||||
|
||||
bench_id = format!("{bench_prefix}::sub::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs - &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "sub");
|
||||
|
||||
bench_id = format!("{bench_prefix}::mul::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs * &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "mul");
|
||||
|
||||
bench_id = format!("{bench_prefix}::bitand::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs & &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "bitand");
|
||||
|
||||
bench_id = format!("{bench_prefix}::bitor::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs | &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "bitor");
|
||||
|
||||
bench_id = format!("{bench_prefix}::bitxor::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs ^ &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "bitxor");
|
||||
|
||||
bench_id = format!("{bench_prefix}::left_shift::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs << &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "left_shift");
|
||||
|
||||
bench_id = format!("{bench_prefix}::right_shift::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs >> &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "right_shift");
|
||||
|
||||
bench_id = format!("{bench_prefix}::left_rotate::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = (&lhs).rotate_left(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "left_rotate");
|
||||
|
||||
bench_id = format!("{bench_prefix}::right_rotate::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = (&lhs).rotate_right(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "right_rotate");
|
||||
|
||||
bench_id = format!("{bench_prefix}::min::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = lhs.min(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "min");
|
||||
|
||||
bench_id = format!("{bench_prefix}::max::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = lhs.max(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "max");
|
||||
write_record(bench_id, display_name);
|
||||
}
|
||||
|
||||
macro_rules! bench_type {
|
||||
($fhe_type:ident) => {
|
||||
macro_rules! bench_type_op (
|
||||
(type_name: $fhe_type:ident, display_name: $display_name:literal, operation: $op:ident) => {
|
||||
::paste::paste! {
|
||||
fn [<bench_ $fhe_type:snake>](c: &mut Criterion, cks: &ClientKey) {
|
||||
bench_fhe_type::<$fhe_type>(c, cks, stringify!($fhe_type), $fhe_type::num_bits());
|
||||
fn [<bench_ $fhe_type:snake _ $op>](c: &mut Criterion, cks: &ClientKey) {
|
||||
bench_fhe_type_op::<$fhe_type, _, _>(
|
||||
c,
|
||||
cks,
|
||||
stringify!($fhe_type),
|
||||
$fhe_type::num_bits(),
|
||||
$display_name,
|
||||
stringify!($op),
|
||||
|lhs, rhs| lhs.$op(rhs)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
);
|
||||
|
||||
macro_rules! generate_typed_benches {
|
||||
($fhe_type:ident) => {
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "add", operation: add);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "overflowing_add", operation: overflowing_add);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "sub", operation: sub);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "overflowing_sub", operation: overflowing_sub);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "mul", operation: mul);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "bitand", operation: bitand);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "bitor", operation: bitor);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "bitxor", operation: bitxor);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "left_shift", operation: shl);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "right_shift", operation: shr);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "left_rotate", operation: rotate_left);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "right_rotate", operation: rotate_right);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "min", operation: min);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "max", operation: max);
|
||||
};
|
||||
}
|
||||
|
||||
bench_type!(FheUint2);
|
||||
bench_type!(FheUint4);
|
||||
bench_type!(FheUint6);
|
||||
bench_type!(FheUint8);
|
||||
bench_type!(FheUint10);
|
||||
bench_type!(FheUint12);
|
||||
bench_type!(FheUint14);
|
||||
bench_type!(FheUint16);
|
||||
bench_type!(FheUint32);
|
||||
bench_type!(FheUint64);
|
||||
bench_type!(FheUint128);
|
||||
// Generate benches for all FheUint types
|
||||
generate_typed_benches!(FheUint2);
|
||||
generate_typed_benches!(FheUint4);
|
||||
generate_typed_benches!(FheUint6);
|
||||
generate_typed_benches!(FheUint8);
|
||||
generate_typed_benches!(FheUint10);
|
||||
generate_typed_benches!(FheUint12);
|
||||
generate_typed_benches!(FheUint14);
|
||||
generate_typed_benches!(FheUint16);
|
||||
generate_typed_benches!(FheUint32);
|
||||
generate_typed_benches!(FheUint64);
|
||||
generate_typed_benches!(FheUint128);
|
||||
|
||||
macro_rules! run_benches {
|
||||
($c:expr, $cks:expr, $($fhe_type:ident),+ $(,)?) => {
|
||||
$(
|
||||
::paste::paste! {
|
||||
[<bench_ $fhe_type:snake _add>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _overflowing_add>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _sub>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _overflowing_sub>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _mul>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _bitand>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _bitor>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _bitxor>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _shl>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _shr>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _rotate_left>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _rotate_right>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _min>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _max>]($c, $cks);
|
||||
}
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
trait TypeDisplay {
|
||||
fn fmt(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -444,7 +372,7 @@ fn main() {
|
||||
|
||||
match env_config.bit_sizes_set {
|
||||
BitSizesSet::Fast => {
|
||||
bench_fhe_uint64(&mut c, &cks);
|
||||
run_benches!(&mut c, &cks, FheUint64);
|
||||
|
||||
// KVStore Benches
|
||||
if benched_device == tfhe::Device::Cpu {
|
||||
@@ -452,17 +380,11 @@ fn main() {
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
bench_fhe_uint2(&mut c, &cks);
|
||||
bench_fhe_uint4(&mut c, &cks);
|
||||
bench_fhe_uint6(&mut c, &cks);
|
||||
bench_fhe_uint8(&mut c, &cks);
|
||||
bench_fhe_uint10(&mut c, &cks);
|
||||
bench_fhe_uint12(&mut c, &cks);
|
||||
bench_fhe_uint14(&mut c, &cks);
|
||||
bench_fhe_uint16(&mut c, &cks);
|
||||
bench_fhe_uint32(&mut c, &cks);
|
||||
bench_fhe_uint64(&mut c, &cks);
|
||||
bench_fhe_uint128(&mut c, &cks);
|
||||
// Call all benchmarks for all types
|
||||
run_benches!(
|
||||
&mut c, &cks, FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12,
|
||||
FheUint14, FheUint16, FheUint32, FheUint64, FheUint128
|
||||
);
|
||||
|
||||
// KVStore Benches
|
||||
if benched_device == tfhe::Device::Cpu {
|
||||
|
||||
Reference in New Issue
Block a user