mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(integer): add flip operation
Add the flip(condition: BooleanBlock, a: T, b: T) -> (T, T) operation that homomorphically flip/swap two values if the given encrypted boolean encrypts true
This commit is contained in:
@@ -677,6 +677,114 @@ fn if_then_else_parallelized(c: &mut Criterion) {
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn flip_parallelized(c: &mut Criterion) {
|
||||
let bench_name = "integer::flip_parallelized";
|
||||
let display_name = "flip";
|
||||
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(60));
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
|
||||
let param_name = param.name();
|
||||
|
||||
let bench_id;
|
||||
|
||||
match get_bench_type() {
|
||||
BenchmarkType::Latency => {
|
||||
let bench_data = LazyCell::new(|| {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let clear_1 = gen_random_u256(&mut rng);
|
||||
let clear_cond = rng.gen_bool(0.5);
|
||||
|
||||
let true_ct = cks.encrypt_radix(clear_0, num_block);
|
||||
let false_ct = cks.encrypt_radix(clear_1, num_block);
|
||||
let condition = cks.encrypt_bool(clear_cond);
|
||||
|
||||
(sks, condition, true_ct, false_ct)
|
||||
});
|
||||
|
||||
bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let (sks, condition, true_ct, false_ct) =
|
||||
(&bench_data.0, &bench_data.1, &bench_data.2, &bench_data.3);
|
||||
|
||||
b.iter(|| sks.flip_parallelized(condition, true_ct, false_ct))
|
||||
});
|
||||
}
|
||||
BenchmarkType::Throughput => {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
// Execute the operation once to know its cost.
|
||||
let clear_0 = gen_random_u256(&mut rng);
|
||||
let true_ct = cks.encrypt_radix(clear_0, num_block);
|
||||
|
||||
let clear_1 = gen_random_u256(&mut rng);
|
||||
let false_ct = cks.encrypt_radix(clear_1, num_block);
|
||||
|
||||
let condition = sks.create_trivial_boolean_block(rng.gen_bool(0.5));
|
||||
|
||||
reset_pbs_count();
|
||||
sks.flip_parallelized(&condition, &true_ct, &false_ct);
|
||||
let pbs_count = max(get_pbs_count(), 1); // Operation might not perform any PBS, so we take 1 as default
|
||||
|
||||
bench_id = format!("{bench_name}::throughput::{param_name}::{bit_size}_bits");
|
||||
bench_group
|
||||
.sample_size(10)
|
||||
.measurement_time(std::time::Duration::from_secs(30));
|
||||
let elements = throughput_num_threads(num_block, pbs_count);
|
||||
bench_group.throughput(Throughput::Elements(elements));
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
let setup_encrypted_values = || {
|
||||
let cts_cond = (0..elements)
|
||||
.map(|_| sks.create_trivial_boolean_block(rng.gen_bool(0.5)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cts_then = (0..elements)
|
||||
.map(|_| cks.encrypt_radix(gen_random_u256(&mut rng), num_block))
|
||||
.collect::<Vec<_>>();
|
||||
let cts_else = (0..elements)
|
||||
.map(|_| cks.encrypt_radix(gen_random_u256(&mut rng), num_block))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
(cts_cond, cts_then, cts_else)
|
||||
};
|
||||
|
||||
b.iter_batched(
|
||||
setup_encrypted_values,
|
||||
|(cts_cond, cts_then, cts_else)| {
|
||||
cts_cond
|
||||
.par_iter()
|
||||
.zip(cts_then.par_iter())
|
||||
.zip(cts_else.par_iter())
|
||||
.for_each(|((condition, true_ct), false_ct)| {
|
||||
sks.flip_parallelized(condition, true_ct, false_ct);
|
||||
})
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
write_to_json::<u64, _>(
|
||||
&bench_id,
|
||||
param,
|
||||
param.name(),
|
||||
display_name,
|
||||
&OperatorType::Atomic,
|
||||
bit_size as u32,
|
||||
vec![param.message_modulus().0.ilog2(); num_block],
|
||||
);
|
||||
}
|
||||
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
fn ciphertexts_sum_parallelized(c: &mut Criterion) {
|
||||
let bench_name = "integer::sum_ciphertexts_parallelized";
|
||||
let display_name = "sum_ctxts";
|
||||
@@ -3350,6 +3458,7 @@ criterion_group!(
|
||||
gt_parallelized,
|
||||
ge_parallelized,
|
||||
if_then_else_parallelized,
|
||||
flip_parallelized,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
@@ -3365,6 +3474,7 @@ criterion_group!(
|
||||
eq_parallelized,
|
||||
gt_parallelized,
|
||||
if_then_else_parallelized,
|
||||
flip_parallelized
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::ciphertext::boolean_value::BooleanBlock;
|
||||
use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::{RadixCiphertext, ServerKey, SignedRadixCiphertext};
|
||||
use crate::shortint::Ciphertext;
|
||||
use rayon::prelude::*;
|
||||
|
||||
pub trait ServerKeyDefaultCMux<TrueCt, FalseCt> {
|
||||
@@ -30,6 +31,13 @@ pub trait ServerKeyDefaultCMux<TrueCt, FalseCt> {
|
||||
) -> Self::Output {
|
||||
self.if_then_else_parallelized(condition, true_ct, false_ct)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
true_ct: TrueCt,
|
||||
false_ct: FalseCt,
|
||||
) -> (Self::Output, Self::Output);
|
||||
}
|
||||
|
||||
impl<T> ServerKeyDefaultCMux<&T, &T> for ServerKey
|
||||
@@ -100,6 +108,110 @@ where
|
||||
let [true_ct, false_ct] = ct_refs;
|
||||
self.unchecked_if_then_else_parallelized(condition, true_ct, false_ct)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
a: &T,
|
||||
b: &T,
|
||||
) -> (Self::Output, Self::Output) {
|
||||
assert_eq!(
|
||||
a.blocks().len(),
|
||||
b.blocks().len(),
|
||||
"Inputs must have the same number of blocks"
|
||||
);
|
||||
|
||||
// To make use if many_lut, we require 1 bit, 1 more bit is required to pack
|
||||
// the condition. Thus 2 bits of carry are required.
|
||||
//
|
||||
// Otherwise we call if_then_else twice, which is less efficient.
|
||||
if self.carry_modulus().0 < (1 << 2) {
|
||||
return rayon::join(
|
||||
|| self.if_then_else_parallelized(condition, b, a),
|
||||
|| self.if_then_else_parallelized(condition, a, b),
|
||||
);
|
||||
}
|
||||
|
||||
let (a, b) = rayon::join(
|
||||
|| self.clean_for_default_op(a),
|
||||
|| self.clean_for_default_op(b),
|
||||
);
|
||||
|
||||
let zero_out_if_true_fn = |packed| {
|
||||
let condition = (packed / self.message_modulus().0) & 1;
|
||||
let value = packed % self.message_modulus().0;
|
||||
(1 - condition) * value
|
||||
};
|
||||
|
||||
let zero_out_if_false_fn = |packed| {
|
||||
let condition = (packed / self.message_modulus().0) & 1;
|
||||
let value = packed % self.message_modulus().0;
|
||||
condition * value
|
||||
};
|
||||
|
||||
let lut = self
|
||||
.key
|
||||
.generate_many_lookup_table(&[&zero_out_if_true_fn, &zero_out_if_false_fn]);
|
||||
|
||||
let scaled_condition = self
|
||||
.key
|
||||
.unchecked_scalar_mul(&condition.0, self.message_modulus().0 as u8);
|
||||
|
||||
let map_condition_lut_on_blocks =
|
||||
|blocks: &[Ciphertext]| -> (Vec<Ciphertext>, Vec<Ciphertext>) {
|
||||
let mut left = Vec::with_capacity(blocks.len());
|
||||
let mut right = Vec::with_capacity(blocks.len());
|
||||
blocks
|
||||
.par_iter()
|
||||
.map(|block| {
|
||||
let block = self.key.unchecked_add(block, &scaled_condition);
|
||||
let mut resulting_blocks = self.key.apply_many_lookup_table(&block, &lut);
|
||||
|
||||
let second_result = resulting_blocks.pop().unwrap();
|
||||
let first_result = resulting_blocks.pop().unwrap();
|
||||
|
||||
(first_result, second_result)
|
||||
})
|
||||
.unzip_into_vecs(&mut left, &mut right);
|
||||
(left, right)
|
||||
};
|
||||
|
||||
let (
|
||||
(mut a_blocks_if_cond, mut a_blocks_if_not_cond),
|
||||
(b_blocks_if_cond, b_blocks_if_not_cond),
|
||||
) = rayon::join(
|
||||
|| map_condition_lut_on_blocks(a.blocks()),
|
||||
|| map_condition_lut_on_blocks(b.blocks()),
|
||||
);
|
||||
|
||||
let clean_lut = self
|
||||
.key
|
||||
.generate_lookup_table(|x| x % self.message_modulus().0);
|
||||
|
||||
let inplace_add_then_clean_blocks =
|
||||
|lhs_blocks: &mut [Ciphertext], rhs_blocks: &[Ciphertext]| {
|
||||
lhs_blocks
|
||||
.par_iter_mut()
|
||||
.zip(rhs_blocks.par_iter())
|
||||
.for_each(|(lhs, rhs)| {
|
||||
self.key.unchecked_add_assign(lhs, rhs);
|
||||
self.key.apply_lookup_table_assign(lhs, &clean_lut);
|
||||
});
|
||||
};
|
||||
rayon::join(
|
||||
|| {
|
||||
inplace_add_then_clean_blocks(&mut a_blocks_if_cond, &b_blocks_if_not_cond);
|
||||
},
|
||||
|| {
|
||||
inplace_add_then_clean_blocks(&mut a_blocks_if_not_cond, &b_blocks_if_cond);
|
||||
},
|
||||
);
|
||||
|
||||
(
|
||||
T::from_blocks(a_blocks_if_cond),
|
||||
T::from_blocks(a_blocks_if_not_cond),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar> ServerKeyDefaultCMux<&RadixCiphertext, Scalar> for ServerKey
|
||||
@@ -163,6 +275,15 @@ where
|
||||
|
||||
self.unchecked_scalar_if_then_else_parallelized(condition, true_ct_ref, false_value)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
true_ct: &RadixCiphertext,
|
||||
false_ct: Scalar,
|
||||
) -> (Self::Output, Self::Output) {
|
||||
self.scalar_flip_parallelized(condition, true_ct, false_ct)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar> ServerKeyDefaultCMux<Scalar, &RadixCiphertext> for ServerKey
|
||||
@@ -217,6 +338,16 @@ where
|
||||
let inverted_condition = self.boolean_bitnot(condition);
|
||||
self.if_then_else_parallelized(&inverted_condition, false_ct, true_value)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
true_value: Scalar,
|
||||
false_ct: &RadixCiphertext,
|
||||
) -> (Self::Output, Self::Output) {
|
||||
let inverted_condition = self.boolean_bitnot(condition);
|
||||
self.flip_parallelized(&inverted_condition, false_ct, true_value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar> ServerKeyDefaultCMux<&SignedRadixCiphertext, Scalar> for ServerKey
|
||||
@@ -243,6 +374,15 @@ where
|
||||
|
||||
self.unchecked_scalar_if_then_else_parallelized(condition, true_ct_ref, false_value)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
true_ct: &SignedRadixCiphertext,
|
||||
false_ct: Scalar,
|
||||
) -> (Self::Output, Self::Output) {
|
||||
self.scalar_flip_parallelized(condition, true_ct, false_ct)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar> ServerKeyDefaultCMux<Scalar, &SignedRadixCiphertext> for ServerKey
|
||||
@@ -297,6 +437,16 @@ where
|
||||
let inverted_condition = self.boolean_bitnot(condition);
|
||||
self.if_then_else_parallelized(&inverted_condition, false_ct, true_value)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
true_value: Scalar,
|
||||
false_ct: &SignedRadixCiphertext,
|
||||
) -> (Self::Output, Self::Output) {
|
||||
let inverted_condition = self.boolean_bitnot(condition);
|
||||
self.flip_parallelized(&inverted_condition, false_ct, true_value)
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerKeyDefaultCMux<&BooleanBlock, &BooleanBlock> for ServerKey {
|
||||
@@ -388,6 +538,67 @@ impl ServerKeyDefaultCMux<&BooleanBlock, &BooleanBlock> for ServerKey {
|
||||
|
||||
BooleanBlock::new_unchecked(lhs)
|
||||
}
|
||||
|
||||
fn flip_parallelized(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
true_ct: &BooleanBlock,
|
||||
false_ct: &BooleanBlock,
|
||||
) -> (Self::Output, Self::Output) {
|
||||
let flip_if_false_fn = |packed| {
|
||||
let condition = (packed / 2) & 1;
|
||||
let value = packed % 2;
|
||||
value * condition
|
||||
};
|
||||
|
||||
let flip_if_true_fn = |packed| {
|
||||
let condition = (packed / 2) & 1;
|
||||
let value = packed % 2;
|
||||
(1 - condition) * value
|
||||
};
|
||||
|
||||
let lut = self
|
||||
.key
|
||||
.generate_many_lookup_table(&[&flip_if_false_fn, &flip_if_true_fn]);
|
||||
|
||||
let scaled_condition = self.key.unchecked_scalar_mul(&condition.0, 2);
|
||||
|
||||
let (vec_a, vec_b) = rayon::join(
|
||||
|| {
|
||||
let block = self.key.unchecked_add(&true_ct.0, &scaled_condition);
|
||||
self.key.apply_many_lookup_table(&block, &lut)
|
||||
},
|
||||
|| {
|
||||
let block = self.key.unchecked_add(&false_ct.0, &scaled_condition);
|
||||
self.key.apply_many_lookup_table(&block, &lut)
|
||||
},
|
||||
);
|
||||
|
||||
let [mut a_if_cond, mut a_if_not_cond] = vec_a.try_into().unwrap();
|
||||
let [b_if_cond, b_if_not_cond] = vec_b.try_into().unwrap();
|
||||
|
||||
self.key
|
||||
.unchecked_add_assign(&mut a_if_cond, &b_if_not_cond);
|
||||
self.key
|
||||
.unchecked_add_assign(&mut a_if_not_cond, &b_if_cond);
|
||||
|
||||
let clean_lut = self.key.generate_lookup_table(|x| x % 2);
|
||||
rayon::join(
|
||||
|| {
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut a_if_cond, &clean_lut)
|
||||
},
|
||||
|| {
|
||||
self.key
|
||||
.apply_lookup_table_assign(&mut a_if_not_cond, &clean_lut)
|
||||
},
|
||||
);
|
||||
|
||||
(
|
||||
BooleanBlock::new_unchecked(a_if_cond),
|
||||
BooleanBlock::new_unchecked(a_if_not_cond),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
@@ -775,4 +986,82 @@ impl ServerKey {
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
fn scalar_flip_parallelized<T, Scalar>(
|
||||
&self,
|
||||
condition: &BooleanBlock,
|
||||
a: &T,
|
||||
b: Scalar,
|
||||
) -> (T, T)
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
let a = self.clean_for_default_op(a);
|
||||
|
||||
// To make use of many_lut, we require 1 bit, 1 more bit is required to pack
|
||||
// the condition. Thus 2 bits of carry are required.
|
||||
//
|
||||
// Otherwise we call if_then_else twice, which is less efficient.
|
||||
if self.carry_modulus().0 < (1 << 2) {
|
||||
let inverted_condition = self.boolean_bitnot(condition);
|
||||
return rayon::join(
|
||||
|| self.unchecked_scalar_if_then_else_parallelized(&inverted_condition, &*a, b),
|
||||
|| self.unchecked_scalar_if_then_else_parallelized(condition, &*a, b),
|
||||
);
|
||||
}
|
||||
|
||||
let n_blocks = a.blocks().len();
|
||||
|
||||
// One of the input is a clear, so we can embed its decomposed value into the LUTs
|
||||
// and by using many_lut we can compute both results at once.
|
||||
let luts = BlockDecomposer::with_block_count(b, self.message_modulus().0.ilog2(), n_blocks)
|
||||
.iter_as::<u64>()
|
||||
.map(|scalar_block| {
|
||||
self.key.generate_many_lookup_table(&[
|
||||
&|packed| {
|
||||
let condition = (packed / self.message_modulus().0) & 1;
|
||||
let value = packed % self.message_modulus().0;
|
||||
if condition == 1 {
|
||||
scalar_block
|
||||
} else {
|
||||
value
|
||||
}
|
||||
},
|
||||
&|packed| {
|
||||
let condition = (packed / self.message_modulus().0) & 1;
|
||||
let value = packed % self.message_modulus().0;
|
||||
if condition == 1 {
|
||||
value
|
||||
} else {
|
||||
scalar_block
|
||||
}
|
||||
},
|
||||
])
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let scaled_condition = self
|
||||
.key
|
||||
.unchecked_scalar_mul(&condition.0, self.message_modulus().0 as u8);
|
||||
|
||||
let mut a_blocks = Vec::with_capacity(n_blocks);
|
||||
let mut b_blocks = Vec::with_capacity(n_blocks);
|
||||
|
||||
a.blocks()
|
||||
.par_iter()
|
||||
.zip(luts.par_iter())
|
||||
.map(|(block, lut)| {
|
||||
let block = self.key.unchecked_add(block, &scaled_condition);
|
||||
let mut results = self.key.apply_many_lookup_table(&block, lut);
|
||||
|
||||
let second = results.pop().unwrap();
|
||||
let first = results.pop().unwrap();
|
||||
|
||||
(first, second)
|
||||
})
|
||||
.unzip_into_vecs(&mut a_blocks, &mut b_blocks);
|
||||
|
||||
(T::from_blocks(a_blocks), T::from_blocks(b_blocks))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ pub(crate) mod tests_unsigned;
|
||||
mod vector_comparisons;
|
||||
mod vector_find;
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use super::ServerKey;
|
||||
use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::RadixCiphertext;
|
||||
@@ -298,4 +300,25 @@ impl ServerKey {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Cleans the input ct so that it is ready to be used in a default ops
|
||||
///
|
||||
/// Returns a Cow::Owned if a clone was necessary for the cleaning,
|
||||
/// Cow::Borrowed otherwise
|
||||
pub(crate) fn clean_for_default_op<'a, T>(&self, ct: &'a T) -> Cow<'a, T>
|
||||
where
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
if ct
|
||||
.blocks()
|
||||
.iter()
|
||||
.any(|block| !block.carry_is_empty() || block.noise_level() != NoiseLevel::NOMINAL)
|
||||
{
|
||||
let mut cloned = ct.clone();
|
||||
self.full_propagate_parallelized(&mut cloned);
|
||||
Cow::Owned(cloned)
|
||||
} else {
|
||||
Cow::Borrowed(ct)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ use std::sync::Arc;
|
||||
create_parameterized_test!(integer_signed_unchecked_if_then_else);
|
||||
create_parameterized_test!(integer_signed_default_if_then_else);
|
||||
create_parameterized_test!(integer_signed_default_scalar_if_then_else);
|
||||
create_parameterized_test!(integer_signed_default_flip);
|
||||
create_parameterized_test!(integer_signed_default_left_scalar_flip);
|
||||
|
||||
fn integer_signed_unchecked_if_then_else<P>(param: P)
|
||||
where
|
||||
@@ -52,6 +54,30 @@ where
|
||||
signed_default_scalar_if_then_else_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_default_flip<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let func =
|
||||
|sks: &ServerKey,
|
||||
cond: &BooleanBlock,
|
||||
lhs: &SignedRadixCiphertext,
|
||||
rhs: &SignedRadixCiphertext| { sks.flip_parallelized(cond, lhs, rhs) };
|
||||
let executor = CpuFunctionExecutor::new(&func);
|
||||
signed_default_flip_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_default_left_scalar_flip<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let func = |sks: &ServerKey, cond: &BooleanBlock, lhs: i64, rhs: &SignedRadixCiphertext| {
|
||||
sks.flip_parallelized(cond, lhs, rhs)
|
||||
};
|
||||
let executor = CpuFunctionExecutor::new(&func);
|
||||
signed_default_left_scalar_flip_test(param, executor);
|
||||
}
|
||||
|
||||
pub(crate) fn signed_default_if_then_else_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
@@ -246,3 +272,165 @@ where
|
||||
assert_eq!(dec_res, if clear_condition { clear_0 } else { clear_1 });
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_default_flip_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(
|
||||
&'a BooleanBlock,
|
||||
&'a SignedRadixCiphertext,
|
||||
&'a SignedRadixCiphertext,
|
||||
),
|
||||
(SignedRadixCiphertext, SignedRadixCiphertext),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as i64 / 2;
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
fn clear_flip(clear_condition: bool, clear_0: i64, clear_1: i64) -> (i64, i64) {
|
||||
if clear_condition {
|
||||
(clear_1, clear_0)
|
||||
} else {
|
||||
(clear_0, clear_1)
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
let clear_condition = rng.gen_bool(0.5);
|
||||
|
||||
let mut ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let mut ctxt_1 = cks.encrypt_signed(clear_1);
|
||||
let ctxt_condition = cks.encrypt_bool(clear_condition);
|
||||
|
||||
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
|
||||
assert!(a.block_carries_are_empty());
|
||||
assert!(b.block_carries_are_empty());
|
||||
|
||||
let dec_a: i64 = cks.decrypt_signed(&a);
|
||||
let dec_b: i64 = cks.decrypt_signed(&b);
|
||||
let expected = clear_flip(clear_condition, clear_0, clear_1);
|
||||
assert_eq!(
|
||||
(dec_a, dec_b),
|
||||
expected,
|
||||
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected:?} got ({dec_a}, {dec_b})",
|
||||
);
|
||||
|
||||
let (a2, b2) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
|
||||
assert_eq!(a, a2, "Operation is not deterministic");
|
||||
assert_eq!(b, b2, "Operation is not deterministic");
|
||||
|
||||
let clear_2 = rng.gen::<i64>() % modulus;
|
||||
let clear_3 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_2 = cks.encrypt_signed(clear_2);
|
||||
let ctxt_3 = cks.encrypt_signed(clear_3);
|
||||
|
||||
// Add to have non empty carries
|
||||
sks.unchecked_add_assign(&mut ctxt_0, &ctxt_2);
|
||||
sks.unchecked_add_assign(&mut ctxt_1, &ctxt_3);
|
||||
assert!(!ctxt_0.block_carries_are_empty());
|
||||
assert!(!ctxt_1.block_carries_are_empty());
|
||||
let clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus);
|
||||
let clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus);
|
||||
|
||||
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
|
||||
assert!(a.block_carries_are_empty());
|
||||
assert!(b.block_carries_are_empty());
|
||||
|
||||
let dec_a: i64 = cks.decrypt_signed(&a);
|
||||
let dec_b: i64 = cks.decrypt_signed(&b);
|
||||
let expected = clear_flip(clear_condition, clear_0, clear_1);
|
||||
assert_eq!(
|
||||
(dec_a, dec_b),
|
||||
expected,
|
||||
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected:?} got ({dec_a}, {dec_b})",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_default_left_scalar_flip_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a BooleanBlock, i64, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, SignedRadixCiphertext),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as i64 / 2;
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
fn clear_flip(clear_condition: bool, clear_0: i64, clear_1: i64) -> (i64, i64) {
|
||||
if clear_condition {
|
||||
(clear_1, clear_0)
|
||||
} else {
|
||||
(clear_0, clear_1)
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
let clear_condition = rng.gen_bool(0.5);
|
||||
|
||||
let ctxt_condition = cks.encrypt_bool(clear_condition);
|
||||
let ctxt_rhs = cks.encrypt_signed(clear_1);
|
||||
|
||||
let (a, b) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
|
||||
assert_eq!(a.blocks.len(), NB_CTXT);
|
||||
assert_eq!(b.blocks.len(), NB_CTXT);
|
||||
assert!(a.block_carries_are_empty());
|
||||
assert!(b.block_carries_are_empty());
|
||||
assert!(a
|
||||
.blocks
|
||||
.iter()
|
||||
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
|
||||
assert!(b
|
||||
.blocks
|
||||
.iter()
|
||||
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
|
||||
|
||||
let dec_a: i64 = cks.decrypt_signed(&a);
|
||||
let dec_b: i64 = cks.decrypt_signed(&b);
|
||||
let expected = clear_flip(clear_condition, clear_0, clear_1);
|
||||
assert_eq!(
|
||||
(dec_a, dec_b),
|
||||
expected,
|
||||
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected:?} got ({dec_a}, {dec_b})",
|
||||
);
|
||||
|
||||
let (a2, b2) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
|
||||
assert_eq!(a, a2, "Operation is not deterministic");
|
||||
assert_eq!(b, b2, "Operation is not deterministic");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ create_parameterized_test!(integer_unchecked_left_scalar_if_then_else);
|
||||
create_parameterized_test!(integer_smart_if_then_else);
|
||||
create_parameterized_test!(integer_default_if_then_else);
|
||||
create_parameterized_test!(integer_default_scalar_if_then_else);
|
||||
create_parameterized_test!(integer_default_flip);
|
||||
create_parameterized_test!(integer_default_left_scalar_flip);
|
||||
|
||||
fn integer_unchecked_left_scalar_if_then_else<P>(param: P)
|
||||
where
|
||||
@@ -65,6 +67,29 @@ where
|
||||
default_scalar_if_then_else_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_flip<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let func = |sks: &ServerKey,
|
||||
cond: &BooleanBlock,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext| { sks.flip_parallelized(cond, lhs, rhs) };
|
||||
let executor = CpuFunctionExecutor::new(&func);
|
||||
default_flip_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_left_scalar_flip<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let func = |sks: &ServerKey, cond: &BooleanBlock, lhs: u64, rhs: &RadixCiphertext| {
|
||||
sks.flip_parallelized(cond, lhs, rhs)
|
||||
};
|
||||
let executor = CpuFunctionExecutor::new(&func);
|
||||
default_left_scalar_flip_test(param, executor);
|
||||
}
|
||||
|
||||
pub(crate) fn smart_if_then_else_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
@@ -331,3 +356,161 @@ where
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_flip_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext),
|
||||
(RadixCiphertext, RadixCiphertext),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
fn clear_flip(clear_condition: bool, clear_0: u64, clear_1: u64) -> (u64, u64) {
|
||||
if clear_condition {
|
||||
(clear_1, clear_0)
|
||||
} else {
|
||||
(clear_0, clear_1)
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
let clear_condition = rng.gen_bool(0.5);
|
||||
|
||||
let mut ctxt_0 = cks.encrypt(clear_0);
|
||||
let mut ctxt_1 = cks.encrypt(clear_1);
|
||||
let ctxt_condition = cks.encrypt_bool(clear_condition);
|
||||
|
||||
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
|
||||
assert!(a.block_carries_are_empty());
|
||||
assert!(b.block_carries_are_empty());
|
||||
|
||||
let dec_a: u64 = cks.decrypt(&a);
|
||||
let dec_b: u64 = cks.decrypt(&b);
|
||||
let expected = clear_flip(clear_condition, clear_0, clear_1);
|
||||
assert_eq!(
|
||||
(dec_a, dec_b),
|
||||
expected,
|
||||
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected:?} got ({dec_a}, {dec_b})",
|
||||
);
|
||||
|
||||
let (a2, b2) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
|
||||
assert_eq!(a, a2, "Operation is not deterministic");
|
||||
assert_eq!(b, b2, "Operation is not deterministic");
|
||||
|
||||
let clear_2 = rng.gen::<u64>() % modulus;
|
||||
let clear_3 = rng.gen::<u64>() % modulus;
|
||||
|
||||
let ctxt_2 = cks.encrypt(clear_2);
|
||||
let ctxt_3 = cks.encrypt(clear_3);
|
||||
|
||||
// Add to have non empty carries
|
||||
sks.unchecked_add_assign(&mut ctxt_0, &ctxt_2);
|
||||
sks.unchecked_add_assign(&mut ctxt_1, &ctxt_3);
|
||||
assert!(!ctxt_0.block_carries_are_empty());
|
||||
assert!(!ctxt_1.block_carries_are_empty());
|
||||
let clear_0 = (clear_0 + clear_2) % modulus;
|
||||
let clear_1 = (clear_1 + clear_3) % modulus;
|
||||
|
||||
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
|
||||
assert!(a.block_carries_are_empty());
|
||||
assert!(b.block_carries_are_empty());
|
||||
|
||||
let dec_a: u64 = cks.decrypt(&a);
|
||||
let dec_b: u64 = cks.decrypt(&b);
|
||||
let expected = clear_flip(clear_condition, clear_0, clear_1);
|
||||
assert_eq!(
|
||||
(dec_a, dec_b),
|
||||
expected,
|
||||
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected:?} got ({dec_a}, {dec_b})",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_left_scalar_flip_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a BooleanBlock, u64, &'a RadixCiphertext),
|
||||
(RadixCiphertext, RadixCiphertext),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
fn clear_flip(clear_condition: bool, clear_0: u64, clear_1: u64) -> (u64, u64) {
|
||||
if clear_condition {
|
||||
(clear_1, clear_0)
|
||||
} else {
|
||||
(clear_0, clear_1)
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
let clear_condition = rng.gen_bool(0.5);
|
||||
|
||||
let ctxt_condition = cks.encrypt_bool(clear_condition);
|
||||
let ctxt_rhs = cks.encrypt(clear_1);
|
||||
|
||||
let (a, b) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
|
||||
assert_eq!(a.blocks.len(), NB_CTXT);
|
||||
assert_eq!(b.blocks.len(), NB_CTXT);
|
||||
assert!(a.block_carries_are_empty());
|
||||
assert!(b.block_carries_are_empty());
|
||||
assert!(a
|
||||
.blocks
|
||||
.iter()
|
||||
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
|
||||
assert!(b
|
||||
.blocks
|
||||
.iter()
|
||||
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
|
||||
|
||||
let dec_a: u64 = cks.decrypt(&a);
|
||||
let dec_b: u64 = cks.decrypt(&b);
|
||||
let expected = clear_flip(clear_condition, clear_0, clear_1);
|
||||
assert_eq!(
|
||||
(dec_a, dec_b),
|
||||
expected,
|
||||
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected:?} got ({dec_a}, {dec_b})",
|
||||
);
|
||||
|
||||
let (a2, b2) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
|
||||
assert_eq!(a, a2, "Operation is not deterministic");
|
||||
assert_eq!(b, b2, "Operation is not deterministic");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user