From 3df542c5f813e78a838defd10a3cc1d26636027b Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Wed, 23 Aug 2023 18:39:05 +0000 Subject: [PATCH] chore(bench): bench integer scalar ops up to 256 bits --- tfhe/benches/integer/bench.rs | 149 ++++++++++++++++++---------------- 1 file changed, 79 insertions(+), 70 deletions(-) diff --git a/tfhe/benches/integer/bench.rs b/tfhe/benches/integer/bench.rs index 662f6a7c8..73cbd3b78 100644 --- a/tfhe/benches/integer/bench.rs +++ b/tfhe/benches/integer/bench.rs @@ -15,12 +15,25 @@ use tfhe::integer::keycache::KEY_CACHE; use tfhe::integer::{RadixCiphertext, ServerKey}; use tfhe::keycache::NamedParam; +use tfhe::integer::U256; + #[allow(unused_imports)] use tfhe::shortint::parameters::{ PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, }; +/// The type used to hold scalar values +/// It must be as big as the largest bit size tested +type ScalarType = U256; + +fn gen_random_u256(rng: &mut ThreadRng) -> U256 { + let clearlow = rng.gen::(); + let clearhigh = rng.gen::(); + + tfhe::integer::U256::from((clearlow, clearhigh)) +} + /// An iterator that yields a succession of combinations /// of parameters and a num_block to achieve a certain bit_size ciphertext /// in radix decomposition @@ -113,23 +126,17 @@ fn bench_server_key_binary_function_dirty_inputs( let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_two_values = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_0 = gen_random_u256(&mut rng); let mut ct_0 = cks.encrypt_radix(clear_0, num_block); - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_1 = gen_random_u256(&mut rng); let mut ct_1 = cks.encrypt_radix(clear_1, num_block); // Raise the degree, so as to ensure worst case path in operations let mut carry_mod = param.carry_modulus().0; while carry_mod > 0 { // Raise the degree, so as to ensure worst case path in operations - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_2 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_2 = gen_random_u256(&mut rng); let ct_2 = cks.encrypt_radix(clear_2, num_block); sks.unchecked_add_assign(&mut ct_0, &ct_2); sks.unchecked_add_assign(&mut ct_1, &ct_2); @@ -187,14 +194,10 @@ fn bench_server_key_binary_function_clean_inputs( let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_two_values = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_0 = gen_random_u256(&mut rng); let ct_0 = cks.encrypt_radix(clear_0, num_block); - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_1 = gen_random_u256(&mut rng); let ct_1 = cks.encrypt_radix(clear_1, num_block); (ct_0, ct_1) @@ -248,20 +251,14 @@ fn bench_server_key_unary_function_dirty_inputs( let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_one_value = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); - + let clear_0 = gen_random_u256(&mut rng); let mut ct_0 = cks.encrypt_radix(clear_0, num_block); // Raise the degree, so as to ensure worst case path in operations let mut carry_mod = param.carry_modulus().0; while carry_mod > 0 { // Raise the degree, so as to ensure worst case path in operations - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_2 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_2 = gen_random_u256(&mut rng); let ct_2 = cks.encrypt_radix(clear_2, num_block); sks.unchecked_add_assign(&mut ct_0, &ct_2); @@ -319,10 +316,7 @@ fn bench_server_key_unary_function_clean_inputs( let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_one_value = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_0 = gen_random_u256(&mut rng); cks.encrypt_radix(clear_0, num_block) }; @@ -350,13 +344,15 @@ fn bench_server_key_unary_function_clean_inputs( bench_group.finish() } -fn bench_server_key_binary_scalar_function_dirty_inputs( +fn bench_server_key_binary_scalar_function_dirty_inputs( c: &mut Criterion, bench_name: &str, display_name: &str, binary_op: F, + rng_func: G, ) where - F: Fn(&ServerKey, &mut RadixCiphertext, u64), + F: Fn(&ServerKey, &mut RadixCiphertext, ScalarType), + G: Fn(&mut ThreadRng, usize) -> ScalarType, { let mut bench_group = c.benchmark_group(bench_name); bench_group @@ -367,15 +363,14 @@ fn bench_server_key_binary_scalar_function_dirty_inputs( for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() { let param_name = param.name(); + let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size); + let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits"); bench_group.bench_function(&bench_id, |b| { let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_one_value = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_0 = gen_random_u256(&mut rng); let mut ct_0 = cks.encrypt_radix(clear_0, num_block); // Raise the degree, so as to ensure worst case path in operations @@ -391,7 +386,7 @@ fn bench_server_key_binary_scalar_function_dirty_inputs( carry_mod -= 1; } - let clear_1 = rng.gen::(); + let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size; (ct_0, clear_1) }; @@ -426,8 +421,8 @@ fn bench_server_key_binary_scalar_function_clean_inputs( binary_op: F, rng_func: G, ) where - F: Fn(&ServerKey, &mut RadixCiphertext, u64), - G: Fn(&mut ThreadRng, usize) -> u64, + F: Fn(&ServerKey, &mut RadixCiphertext, ScalarType), + G: Fn(&mut ThreadRng, usize) -> ScalarType, { let mut bench_group = c.benchmark_group(bench_name); bench_group @@ -436,24 +431,22 @@ fn bench_server_key_binary_scalar_function_clean_inputs( let mut rng = rand::thread_rng(); for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() { - if bit_size > 64 { + if bit_size > ScalarType::BITS as usize { break; } let param_name = param.name(); + let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size); + let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits_scalar_{bit_size}"); bench_group.bench_function(&bench_id, |b| { let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_one_value = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_0 = gen_random_u256(&mut rng); let ct_0 = cks.encrypt_radix(clear_0, num_block); - // Avoid overflow issues for u64 where we would take values mod 1 - let clear_1 = (rng_func(&mut rng, bit_size) as u128 % (1u128 << bit_size)) as u64; + let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size; (ct_0, clear_1) }; @@ -482,18 +475,18 @@ fn bench_server_key_binary_scalar_function_clean_inputs( } // Functions used to apply different way of selecting a scalar based on the context. -fn default_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 { - rng.gen::() +fn default_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType { + gen_random_u256(rng) } -fn shift_scalar(_rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 { +fn shift_scalar(_rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType { // Shifting by one is the worst case scenario. - 1 + ScalarType::ONE } -fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 { +fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType { loop { - let scalar = rng.gen_range(3u64..=u64::MAX); + let scalar = gen_random_u256(rng); // If scalar is power of two, it is just a shit, which is an happy path. if !scalar.is_power_of_two() { return scalar; @@ -501,11 +494,12 @@ fn mul_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> u64 { } } -fn div_scalar(rng: &mut ThreadRng, clear_bit_size: usize) -> u64 { +fn div_scalar(rng: &mut ThreadRng, clear_bit_size: usize) -> ScalarType { loop { - let scalar = rng.gen_range(1..=u64::MAX); - // Avoid overflow issues for u64 where we would take values mod 1 - if (scalar as u128 % (1u128 << clear_bit_size)) != 0 { + let scalar = gen_random_u256(rng); + let max_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - clear_bit_size); + let scalar = scalar & max_for_bit_size; + if scalar != ScalarType::ZERO { return scalar; } } @@ -529,14 +523,10 @@ fn if_then_else_parallelized(c: &mut Criterion) { let (cks, sks) = KEY_CACHE.get_from_params(param); let encrypt_tree_values = || { - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_0 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_0 = gen_random_u256(&mut rng); let ct_0 = cks.encrypt_radix(clear_0, num_block); - let clearlow = rng.gen::(); - let clearhigh = rng.gen::(); - let clear_1 = tfhe::integer::U256::from((clearlow, clearhigh)); + let clear_1 = gen_random_u256(&mut rng); let ct_1 = cks.encrypt_radix(clear_1, num_block); let cond = sks.create_trivial_radix(rng.gen_bool(0.5) as u64, num_block); @@ -624,7 +614,7 @@ macro_rules! define_server_key_bench_default_fn ( ); macro_rules! define_server_key_bench_scalar_fn ( - (method_name: $server_key_method:ident, display_name:$name:ident) => { + (method_name: $server_key_method:ident, display_name:$name:ident, rng_func:$($rng_fn:tt)*) => { fn $server_key_method(c: &mut Criterion) { bench_server_key_binary_scalar_function_dirty_inputs( c, @@ -632,7 +622,9 @@ macro_rules! define_server_key_bench_scalar_fn ( stringify!($name), |server_key, lhs, rhs| { server_key.$server_key_method(lhs, rhs); - }) + }, + $($rng_fn)* + ) } } ); @@ -646,7 +638,9 @@ macro_rules! define_server_key_bench_scalar_default_fn ( stringify!($name), |server_key, lhs, rhs| { server_key.$server_key_method(lhs, rhs); - }, $($rng_fn)*) + }, + $($rng_fn)* + ) } } ); @@ -696,21 +690,36 @@ define_server_key_bench_default_fn!( display_name: bitxor ); -define_server_key_bench_scalar_fn!(method_name: smart_scalar_add, display_name: add); -define_server_key_bench_scalar_fn!(method_name: smart_scalar_sub, display_name: sub); -define_server_key_bench_scalar_fn!(method_name: smart_scalar_mul, display_name: mul); +define_server_key_bench_scalar_fn!( + method_name: smart_scalar_add, + display_name: add, + rng_func: default_scalar +); +define_server_key_bench_scalar_fn!( + method_name: smart_scalar_sub, + display_name: sub, + rng_func: default_scalar +); +define_server_key_bench_scalar_fn!( + method_name: smart_scalar_mul, + display_name: mul, + rng_func: mul_scalar +); define_server_key_bench_scalar_fn!( method_name: smart_scalar_add_parallelized, - display_name: add + display_name: add, + rng_func: default_scalar ); define_server_key_bench_scalar_fn!( method_name: smart_scalar_sub_parallelized, - display_name: sub + display_name: sub, + rng_func: default_scalar, ); define_server_key_bench_scalar_fn!( method_name: smart_scalar_mul_parallelized, - display_name: mul + display_name: mul, + rng_func: mul_scalar ); define_server_key_bench_scalar_default_fn!( @@ -810,7 +819,7 @@ define_server_key_bench_scalar_default_fn!( rng_func: default_scalar ); define_server_key_bench_scalar_default_fn!( - method_name: unchecked_small_scalar_mul, + method_name: unchecked_scalar_mul_parallelized, display_name: mul, rng_func: mul_scalar ); @@ -1041,7 +1050,7 @@ criterion_group!( unchecked_scalar_ops, unchecked_scalar_add, unchecked_scalar_sub, - unchecked_small_scalar_mul, + unchecked_scalar_mul_parallelized, unchecked_bitand_parallelized, unchecked_bitor_parallelized, unchecked_bitxor_parallelized,