mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
chore(bench): bench integer scalar ops up to 256 bits
This commit is contained in:
@@ -15,12 +15,25 @@ use tfhe::integer::keycache::KEY_CACHE;
|
||||
use tfhe::integer::{RadixCiphertext, ServerKey};
|
||||
use tfhe::keycache::NamedParam;
|
||||
|
||||
use tfhe::integer::U256;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use tfhe::shortint::parameters::{
|
||||
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
|
||||
PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
|
||||
};
|
||||
|
||||
/// The type used to hold scalar values
|
||||
/// It must be as big as the largest bit size tested
|
||||
type ScalarType = U256;
|
||||
|
||||
fn gen_random_u256(rng: &mut ThreadRng) -> U256 {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
|
||||
tfhe::integer::U256::from((clearlow, clearhigh))
|
||||
}
|
||||
|
||||
/// An iterator that yields a succession of combinations
|
||||
/// of parameters and a num_block to achieve a certain bit_size ciphertext
|
||||
/// in radix decomposition
|
||||
@@ -113,23 +126,17 @@ fn bench_server_key_binary_function_dirty_inputs<F>(
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_two_values = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let mut ct_0 = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_1 = gen_random_u256(&mut rng);
|
||||
let mut ct_1 = cks.encrypt_radix(clear_1, num_block);
|
||||
|
||||
// Raise the degree, so as to ensure worst case path in operations
|
||||
let mut carry_mod = param.carry_modulus().0;
|
||||
while carry_mod > 0 {
|
||||
// Raise the degree, so as to ensure worst case path in operations
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_2 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_2 = gen_random_u256(&mut rng);
|
||||
let ct_2 = cks.encrypt_radix(clear_2, num_block);
|
||||
sks.unchecked_add_assign(&mut ct_0, &ct_2);
|
||||
sks.unchecked_add_assign(&mut ct_1, &ct_2);
|
||||
@@ -187,14 +194,10 @@ fn bench_server_key_binary_function_clean_inputs<F>(
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_two_values = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let ct_0 = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_1 = gen_random_u256(&mut rng);
|
||||
let ct_1 = cks.encrypt_radix(clear_1, num_block);
|
||||
|
||||
(ct_0, ct_1)
|
||||
@@ -248,20 +251,14 @@ fn bench_server_key_unary_function_dirty_inputs<F>(
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_one_value = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let mut ct_0 = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
// Raise the degree, so as to ensure worst case path in operations
|
||||
let mut carry_mod = param.carry_modulus().0;
|
||||
while carry_mod > 0 {
|
||||
// Raise the degree, so as to ensure worst case path in operations
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_2 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_2 = gen_random_u256(&mut rng);
|
||||
let ct_2 = cks.encrypt_radix(clear_2, num_block);
|
||||
sks.unchecked_add_assign(&mut ct_0, &ct_2);
|
||||
|
||||
@@ -319,10 +316,7 @@ fn bench_server_key_unary_function_clean_inputs<F>(
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_one_value = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
|
||||
cks.encrypt_radix(clear_0, num_block)
|
||||
};
|
||||
@@ -350,13 +344,15 @@ fn bench_server_key_unary_function_clean_inputs<F>(
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn bench_server_key_binary_scalar_function_dirty_inputs<F>(
|
||||
fn bench_server_key_binary_scalar_function_dirty_inputs<F, G>(
|
||||
c: &mut Criterion,
|
||||
bench_name: &str,
|
||||
display_name: &str,
|
||||
binary_op: F,
|
||||
rng_func: G,
|
||||
) where
|
||||
F: Fn(&ServerKey, &mut RadixCiphertext, u64),
|
||||
F: Fn(&ServerKey, &mut RadixCiphertext, ScalarType),
|
||||
G: Fn(&mut ThreadRng, usize) -> ScalarType,
|
||||
{
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
@@ -367,15 +363,14 @@ fn bench_server_key_binary_scalar_function_dirty_inputs<F>(
|
||||
for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
|
||||
let param_name = param.name();
|
||||
|
||||
let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size);
|
||||
|
||||
let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_one_value = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let mut ct_0 = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
// Raise the degree, so as to ensure worst case path in operations
|
||||
@@ -391,7 +386,7 @@ fn bench_server_key_binary_scalar_function_dirty_inputs<F>(
|
||||
carry_mod -= 1;
|
||||
}
|
||||
|
||||
let clear_1 = rng.gen::<u64>();
|
||||
let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size;
|
||||
|
||||
(ct_0, clear_1)
|
||||
};
|
||||
@@ -426,8 +421,8 @@ fn bench_server_key_binary_scalar_function_clean_inputs<F, G>(
|
||||
binary_op: F,
|
||||
rng_func: G,
|
||||
) where
|
||||
F: Fn(&ServerKey, &mut RadixCiphertext, u64),
|
||||
G: Fn(&mut ThreadRng, usize) -> u64,
|
||||
F: Fn(&ServerKey, &mut RadixCiphertext, ScalarType),
|
||||
G: Fn(&mut ThreadRng, usize) -> ScalarType,
|
||||
{
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
@@ -436,24 +431,22 @@ fn bench_server_key_binary_scalar_function_clean_inputs<F, G>(
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
|
||||
if bit_size > 64 {
|
||||
if bit_size > ScalarType::BITS as usize {
|
||||
break;
|
||||
}
|
||||
let param_name = param.name();
|
||||
|
||||
let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size);
|
||||
|
||||
let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits_scalar_{bit_size}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_one_value = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let ct_0 = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
// Avoid overflow issues for u64 where we would take values mod 1
|
||||
let clear_1 = (rng_func(&mut rng, bit_size) as u128 % (1u128 << bit_size)) as u64;
|
||||
let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size;
|
||||
|
||||
(ct_0, clear_1)
|
||||
};
|
||||
@@ -482,18 +475,18 @@ fn bench_server_key_binary_scalar_function_clean_inputs<F, G>(
|
||||
}
|
||||
|
||||
// Functions used to apply different way of selecting a scalar based on the context.
|
||||
fn default_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
|
||||
rng.gen::<u64>()
|
||||
fn default_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
|
||||
gen_random_u256(rng)
|
||||
}
|
||||
|
||||
fn shift_scalar(_rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
|
||||
fn shift_scalar(_rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
|
||||
// Shifting by one is the worst case scenario.
|
||||
1
|
||||
ScalarType::ONE
|
||||
}
|
||||
|
||||
fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
|
||||
fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
|
||||
loop {
|
||||
let scalar = rng.gen_range(3u64..=u64::MAX);
|
||||
let scalar = gen_random_u256(rng);
|
||||
// If scalar is power of two, it is just a shit, which is an happy path.
|
||||
if !scalar.is_power_of_two() {
|
||||
return scalar;
|
||||
@@ -501,11 +494,12 @@ fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 {
|
||||
}
|
||||
}
|
||||
|
||||
fn div_scalar(rng: &mut ThreadRng, clear_bit_size: usize) -> u64 {
|
||||
fn div_scalar(rng: &mut ThreadRng, clear_bit_size: usize) -> ScalarType {
|
||||
loop {
|
||||
let scalar = rng.gen_range(1..=u64::MAX);
|
||||
// Avoid overflow issues for u64 where we would take values mod 1
|
||||
if (scalar as u128 % (1u128 << clear_bit_size)) != 0 {
|
||||
let scalar = gen_random_u256(rng);
|
||||
let max_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - clear_bit_size);
|
||||
let scalar = scalar & max_for_bit_size;
|
||||
if scalar != ScalarType::ZERO {
|
||||
return scalar;
|
||||
}
|
||||
}
|
||||
@@ -529,14 +523,10 @@ fn if_then_else_parallelized(c: &mut Criterion) {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
let encrypt_tree_values = || {
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let ct_0 = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
let clearlow = rng.gen::<u128>();
|
||||
let clearhigh = rng.gen::<u128>();
|
||||
let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh));
|
||||
let clear_1 = gen_random_u256(&mut rng);
|
||||
let ct_1 = cks.encrypt_radix(clear_1, num_block);
|
||||
|
||||
let cond = sks.create_trivial_radix(rng.gen_bool(0.5) as u64, num_block);
|
||||
@@ -624,7 +614,7 @@ macro_rules! define_server_key_bench_default_fn (
|
||||
);
|
||||
|
||||
macro_rules! define_server_key_bench_scalar_fn (
|
||||
(method_name: $server_key_method:ident, display_name:$name:ident) => {
|
||||
(method_name: $server_key_method:ident, display_name:$name:ident, rng_func:$($rng_fn:tt)*) => {
|
||||
fn $server_key_method(c: &mut Criterion) {
|
||||
bench_server_key_binary_scalar_function_dirty_inputs(
|
||||
c,
|
||||
@@ -632,7 +622,9 @@ macro_rules! define_server_key_bench_scalar_fn (
|
||||
stringify!($name),
|
||||
|server_key, lhs, rhs| {
|
||||
server_key.$server_key_method(lhs, rhs);
|
||||
})
|
||||
},
|
||||
$($rng_fn)*
|
||||
)
|
||||
}
|
||||
}
|
||||
);
|
||||
@@ -646,7 +638,9 @@ macro_rules! define_server_key_bench_scalar_default_fn (
|
||||
stringify!($name),
|
||||
|server_key, lhs, rhs| {
|
||||
server_key.$server_key_method(lhs, rhs);
|
||||
}, $($rng_fn)*)
|
||||
},
|
||||
$($rng_fn)*
|
||||
)
|
||||
}
|
||||
}
|
||||
);
|
||||
@@ -696,21 +690,36 @@ define_server_key_bench_default_fn!(
|
||||
display_name: bitxor
|
||||
);
|
||||
|
||||
define_server_key_bench_scalar_fn!(method_name: smart_scalar_add, display_name: add);
|
||||
define_server_key_bench_scalar_fn!(method_name: smart_scalar_sub, display_name: sub);
|
||||
define_server_key_bench_scalar_fn!(method_name: smart_scalar_mul, display_name: mul);
|
||||
define_server_key_bench_scalar_fn!(
|
||||
method_name: smart_scalar_add,
|
||||
display_name: add,
|
||||
rng_func: default_scalar
|
||||
);
|
||||
define_server_key_bench_scalar_fn!(
|
||||
method_name: smart_scalar_sub,
|
||||
display_name: sub,
|
||||
rng_func: default_scalar
|
||||
);
|
||||
define_server_key_bench_scalar_fn!(
|
||||
method_name: smart_scalar_mul,
|
||||
display_name: mul,
|
||||
rng_func: mul_scalar
|
||||
);
|
||||
|
||||
define_server_key_bench_scalar_fn!(
|
||||
method_name: smart_scalar_add_parallelized,
|
||||
display_name: add
|
||||
display_name: add,
|
||||
rng_func: default_scalar
|
||||
);
|
||||
define_server_key_bench_scalar_fn!(
|
||||
method_name: smart_scalar_sub_parallelized,
|
||||
display_name: sub
|
||||
display_name: sub,
|
||||
rng_func: default_scalar,
|
||||
);
|
||||
define_server_key_bench_scalar_fn!(
|
||||
method_name: smart_scalar_mul_parallelized,
|
||||
display_name: mul
|
||||
display_name: mul,
|
||||
rng_func: mul_scalar
|
||||
);
|
||||
|
||||
define_server_key_bench_scalar_default_fn!(
|
||||
@@ -810,7 +819,7 @@ define_server_key_bench_scalar_default_fn!(
|
||||
rng_func: default_scalar
|
||||
);
|
||||
define_server_key_bench_scalar_default_fn!(
|
||||
method_name: unchecked_small_scalar_mul,
|
||||
method_name: unchecked_scalar_mul_parallelized,
|
||||
display_name: mul,
|
||||
rng_func: mul_scalar
|
||||
);
|
||||
@@ -1041,7 +1050,7 @@ criterion_group!(
|
||||
unchecked_scalar_ops,
|
||||
unchecked_scalar_add,
|
||||
unchecked_scalar_sub,
|
||||
unchecked_small_scalar_mul,
|
||||
unchecked_scalar_mul_parallelized,
|
||||
unchecked_bitand_parallelized,
|
||||
unchecked_bitor_parallelized,
|
||||
unchecked_bitxor_parallelized,
|
||||
|
||||
Reference in New Issue
Block a user