chore(bench): bench integer scalar ops up to 256 bits

This commit is contained in:
tmontaigu
2023-08-23 18:39:05 +00:00
parent 37623eedf3
commit 3df542c5f8

View File

@@ -15,12 +15,25 @@ use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::{RadixCiphertext, ServerKey};
use tfhe::keycache::NamedParam;
use tfhe::integer::U256;
#[allow(unused_imports)]
use tfhe::shortint::parameters::{
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
};
/// The type used to hold scalar values
/// It must be as big as the largest bit size tested
type ScalarType = U256;
fn gen_random_u256(rng: &mut ThreadRng) -> U256 {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
tfhe::integer::U256::from((clearlow, clearhigh))
}
/// An iterator that yields a succession of combinations
/// of parameters and a num_block to achieve a certain bit_size ciphertext
/// in radix decomposition
@@ -113,23 +126,17 @@ fn bench_server_key_binary_function_dirty_inputs<F>(
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_two_values = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
let mut ct_0 = cks.encrypt_radix(clear_0, num_block);
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_1 = gen_random_u256(&mut rng);
let mut ct_1 = cks.encrypt_radix(clear_1, num_block);
// Raise the degree, so as to ensure worst case path in operations
let mut carry_mod = param.carry_modulus().0;
while carry_mod > 0 {
// Raise the degree, so as to ensure worst case path in operations
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_2 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_2 = gen_random_u256(&mut rng);
let ct_2 = cks.encrypt_radix(clear_2, num_block);
sks.unchecked_add_assign(&mut ct_0, &ct_2);
sks.unchecked_add_assign(&mut ct_1, &ct_2);
@@ -187,14 +194,10 @@ fn bench_server_key_binary_function_clean_inputs<F>(
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_two_values = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
let ct_0 = cks.encrypt_radix(clear_0, num_block);
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_1 = gen_random_u256(&mut rng);
let ct_1 = cks.encrypt_radix(clear_1, num_block);
(ct_0, ct_1)
@@ -248,20 +251,14 @@ fn bench_server_key_unary_function_dirty_inputs<F>(
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_one_value = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
let mut ct_0 = cks.encrypt_radix(clear_0, num_block);
// Raise the degree, so as to ensure worst case path in operations
let mut carry_mod = param.carry_modulus().0;
while carry_mod > 0 {
// Raise the degree, so as to ensure worst case path in operations
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_2 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_2 = gen_random_u256(&mut rng);
let ct_2 = cks.encrypt_radix(clear_2, num_block);
sks.unchecked_add_assign(&mut ct_0, &ct_2);
@@ -319,10 +316,7 @@ fn bench_server_key_unary_function_clean_inputs<F>(
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_one_value = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
cks.encrypt_radix(clear_0, num_block)
};
@@ -350,13 +344,15 @@ fn bench_server_key_unary_function_clean_inputs<F>(
bench_group.finish()
}
fn bench_server_key_binary_scalar_function_dirty_inputs<F>(
fn bench_server_key_binary_scalar_function_dirty_inputs<F, G>(
c: &mut Criterion,
bench_name: &str,
display_name: &str,
binary_op: F,
rng_func: G,
) where
F: Fn(&ServerKey, &mut RadixCiphertext, u64),
F: Fn(&ServerKey, &mut RadixCiphertext, ScalarType),
G: Fn(&mut ThreadRng, usize) -> ScalarType,
{
let mut bench_group = c.benchmark_group(bench_name);
bench_group
@@ -367,15 +363,14 @@ fn bench_server_key_binary_scalar_function_dirty_inputs<F>(
for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
let param_name = param.name();
let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size);
let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits");
bench_group.bench_function(&bench_id, |b| {
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_one_value = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
let mut ct_0 = cks.encrypt_radix(clear_0, num_block);
// Raise the degree, so as to ensure worst case path in operations
@@ -391,7 +386,7 @@ fn bench_server_key_binary_scalar_function_dirty_inputs<F>(
carry_mod -= 1;
}
let clear_1 = rng.gen::<u64>();
let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size;
(ct_0, clear_1)
};
@@ -426,8 +421,8 @@ fn bench_server_key_binary_scalar_function_clean_inputs<F, G>(
binary_op: F,
rng_func: G,
) where
F: Fn(&ServerKey, &mut RadixCiphertext, u64),
G: Fn(&mut ThreadRng, usize) -> u64,
F: Fn(&ServerKey, &mut RadixCiphertext, ScalarType),
G: Fn(&mut ThreadRng, usize) -> ScalarType,
{
let mut bench_group = c.benchmark_group(bench_name);
bench_group
@@ -436,24 +431,22 @@ fn bench_server_key_binary_scalar_function_clean_inputs<F, G>(
let mut rng = rand::thread_rng();
for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
if bit_size > 64 {
if bit_size > ScalarType::BITS as usize {
break;
}
let param_name = param.name();
let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size);
let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits_scalar_{bit_size}");
bench_group.bench_function(&bench_id, |b| {
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_one_value = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
let ct_0 = cks.encrypt_radix(clear_0, num_block);
// Avoid overflow issues for u64 where we would take values mod 1
let clear_1 = (rng_func(&mut rng, bit_size) as u128 % (1u128 << bit_size)) as u64;
let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size;
(ct_0, clear_1)
};
@@ -482,18 +475,18 @@ fn bench_server_key_binary_scalar_function_clean_inputs<F, G>(
}
// Functions used to apply different way of selecting a scalar based on the context.
fn default_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
rng.gen::<u64>()
fn default_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
gen_random_u256(rng)
}
fn shift_scalar(_rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
fn shift_scalar(_rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
// Shifting by one is the worst case scenario.
1
ScalarType::ONE
}
fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
loop {
let scalar = rng.gen_range(3u64..=u64::MAX);
let scalar = gen_random_u256(rng);
// If scalar is power of two, it is just a shit, which is an happy path.
if !scalar.is_power_of_two() {
return scalar;
@@ -501,11 +494,12 @@ fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
}
}
fn div_scalar(rng: &mut ThreadRng, clear_bit_size: usize) -> u64 {
fn div_scalar(rng: &mut ThreadRng, clear_bit_size: usize) -> ScalarType {
loop {
let scalar = rng.gen_range(1..=u64::MAX);
// Avoid overflow issues for u64 where we would take values mod 1
if (scalar as u128 % (1u128 << clear_bit_size)) != 0 {
let scalar = gen_random_u256(rng);
let max_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - clear_bit_size);
let scalar = scalar & max_for_bit_size;
if scalar != ScalarType::ZERO {
return scalar;
}
}
@@ -529,14 +523,10 @@ fn if_then_else_parallelized(c: &mut Criterion) {
let (cks, sks) = KEY_CACHE.get_from_params(param);
let encrypt_tree_values = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_0 = gen_random_u256(&mut rng);
let ct_0 = cks.encrypt_radix(clear_0, num_block);
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
let clear_1 = gen_random_u256(&mut rng);
let ct_1 = cks.encrypt_radix(clear_1, num_block);
let cond = sks.create_trivial_radix(rng.gen_bool(0.5) as u64, num_block);
@@ -624,7 +614,7 @@ macro_rules! define_server_key_bench_default_fn (
);
macro_rules! define_server_key_bench_scalar_fn (
(method_name: $server_key_method:ident, display_name:$name:ident) => {
(method_name: $server_key_method:ident, display_name:$name:ident, rng_func:$($rng_fn:tt)*) => {
fn $server_key_method(c: &mut Criterion) {
bench_server_key_binary_scalar_function_dirty_inputs(
c,
@@ -632,7 +622,9 @@ macro_rules! define_server_key_bench_scalar_fn (
stringify!($name),
|server_key, lhs, rhs| {
server_key.$server_key_method(lhs, rhs);
})
},
$($rng_fn)*
)
}
}
);
@@ -646,7 +638,9 @@ macro_rules! define_server_key_bench_scalar_default_fn (
stringify!($name),
|server_key, lhs, rhs| {
server_key.$server_key_method(lhs, rhs);
}, $($rng_fn)*)
},
$($rng_fn)*
)
}
}
);
@@ -696,21 +690,36 @@ define_server_key_bench_default_fn!(
display_name: bitxor
);
define_server_key_bench_scalar_fn!(method_name: smart_scalar_add, display_name: add);
define_server_key_bench_scalar_fn!(method_name: smart_scalar_sub, display_name: sub);
define_server_key_bench_scalar_fn!(method_name: smart_scalar_mul, display_name: mul);
define_server_key_bench_scalar_fn!(
method_name: smart_scalar_add,
display_name: add,
rng_func: default_scalar
);
define_server_key_bench_scalar_fn!(
method_name: smart_scalar_sub,
display_name: sub,
rng_func: default_scalar
);
define_server_key_bench_scalar_fn!(
method_name: smart_scalar_mul,
display_name: mul,
rng_func: mul_scalar
);
define_server_key_bench_scalar_fn!(
method_name: smart_scalar_add_parallelized,
display_name: add
display_name: add,
rng_func: default_scalar
);
define_server_key_bench_scalar_fn!(
method_name: smart_scalar_sub_parallelized,
display_name: sub
display_name: sub,
rng_func: default_scalar,
);
define_server_key_bench_scalar_fn!(
method_name: smart_scalar_mul_parallelized,
display_name: mul
display_name: mul,
rng_func: mul_scalar
);
define_server_key_bench_scalar_default_fn!(
@@ -810,7 +819,7 @@ define_server_key_bench_scalar_default_fn!(
rng_func: default_scalar
);
define_server_key_bench_scalar_default_fn!(
method_name: unchecked_small_scalar_mul,
method_name: unchecked_scalar_mul_parallelized,
display_name: mul,
rng_func: mul_scalar
);
@@ -1041,7 +1050,7 @@ criterion_group!(
unchecked_scalar_ops,
unchecked_scalar_add,
unchecked_scalar_sub,
unchecked_small_scalar_mul,
unchecked_scalar_mul_parallelized,
unchecked_bitand_parallelized,
unchecked_bitor_parallelized,
unchecked_bitxor_parallelized,