feat(integer): add flip operation

Add the flip(condition: BooleanBlock, a: T, b: T) -> (T, T)
operation that homomorphically flip/swap two values if the
given encrypted boolean encrypts true
This commit is contained in:
tmontaigu
2025-09-08 10:29:28 +02:00
parent 63e5504c80
commit e8dc403ebd
5 changed files with 793 additions and 0 deletions

View File

@@ -677,6 +677,114 @@ fn if_then_else_parallelized(c: &mut Criterion) {
bench_group.finish()
}
fn flip_parallelized(c: &mut Criterion) {
let bench_name = "integer::flip_parallelized";
let display_name = "flip";
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60));
let mut rng = rand::thread_rng();
for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
let param_name = param.name();
let bench_id;
match get_bench_type() {
BenchmarkType::Latency => {
let bench_data = LazyCell::new(|| {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let clear_0 = gen_random_u256(&mut rng);
let clear_1 = gen_random_u256(&mut rng);
let clear_cond = rng.gen_bool(0.5);
let true_ct = cks.encrypt_radix(clear_0, num_block);
let false_ct = cks.encrypt_radix(clear_1, num_block);
let condition = cks.encrypt_bool(clear_cond);
(sks, condition, true_ct, false_ct)
});
bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits");
bench_group.bench_function(&bench_id, |b| {
let (sks, condition, true_ct, false_ct) =
(&bench_data.0, &bench_data.1, &bench_data.2, &bench_data.3);
b.iter(|| sks.flip_parallelized(condition, true_ct, false_ct))
});
}
BenchmarkType::Throughput => {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
// Execute the operation once to know its cost.
let clear_0 = gen_random_u256(&mut rng);
let true_ct = cks.encrypt_radix(clear_0, num_block);
let clear_1 = gen_random_u256(&mut rng);
let false_ct = cks.encrypt_radix(clear_1, num_block);
let condition = sks.create_trivial_boolean_block(rng.gen_bool(0.5));
reset_pbs_count();
sks.flip_parallelized(&condition, &true_ct, &false_ct);
let pbs_count = max(get_pbs_count(), 1); // Operation might not perform any PBS, so we take 1 as default
bench_id = format!("{bench_name}::throughput::{param_name}::{bit_size}_bits");
bench_group
.sample_size(10)
.measurement_time(std::time::Duration::from_secs(30));
let elements = throughput_num_threads(num_block, pbs_count);
bench_group.throughput(Throughput::Elements(elements));
bench_group.bench_function(&bench_id, |b| {
let setup_encrypted_values = || {
let cts_cond = (0..elements)
.map(|_| sks.create_trivial_boolean_block(rng.gen_bool(0.5)))
.collect::<Vec<_>>();
let cts_then = (0..elements)
.map(|_| cks.encrypt_radix(gen_random_u256(&mut rng), num_block))
.collect::<Vec<_>>();
let cts_else = (0..elements)
.map(|_| cks.encrypt_radix(gen_random_u256(&mut rng), num_block))
.collect::<Vec<_>>();
(cts_cond, cts_then, cts_else)
};
b.iter_batched(
setup_encrypted_values,
|(cts_cond, cts_then, cts_else)| {
cts_cond
.par_iter()
.zip(cts_then.par_iter())
.zip(cts_else.par_iter())
.for_each(|((condition, true_ct), false_ct)| {
sks.flip_parallelized(condition, true_ct, false_ct);
})
},
criterion::BatchSize::SmallInput,
);
});
}
}
write_to_json::<u64, _>(
&bench_id,
param,
param.name(),
display_name,
&OperatorType::Atomic,
bit_size as u32,
vec![param.message_modulus().0.ilog2(); num_block],
);
}
bench_group.finish()
}
fn ciphertexts_sum_parallelized(c: &mut Criterion) {
let bench_name = "integer::sum_ciphertexts_parallelized";
let display_name = "sum_ctxts";
@@ -3350,6 +3458,7 @@ criterion_group!(
gt_parallelized,
ge_parallelized,
if_then_else_parallelized,
flip_parallelized,
);
criterion_group!(
@@ -3365,6 +3474,7 @@ criterion_group!(
eq_parallelized,
gt_parallelized,
if_then_else_parallelized,
flip_parallelized
);
criterion_group!(

View File

@@ -2,6 +2,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::ciphertext::boolean_value::BooleanBlock;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::{RadixCiphertext, ServerKey, SignedRadixCiphertext};
use crate::shortint::Ciphertext;
use rayon::prelude::*;
pub trait ServerKeyDefaultCMux<TrueCt, FalseCt> {
@@ -30,6 +31,13 @@ pub trait ServerKeyDefaultCMux<TrueCt, FalseCt> {
) -> Self::Output {
self.if_then_else_parallelized(condition, true_ct, false_ct)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
true_ct: TrueCt,
false_ct: FalseCt,
) -> (Self::Output, Self::Output);
}
impl<T> ServerKeyDefaultCMux<&T, &T> for ServerKey
@@ -100,6 +108,110 @@ where
let [true_ct, false_ct] = ct_refs;
self.unchecked_if_then_else_parallelized(condition, true_ct, false_ct)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
a: &T,
b: &T,
) -> (Self::Output, Self::Output) {
assert_eq!(
a.blocks().len(),
b.blocks().len(),
"Inputs must have the same number of blocks"
);
// To make use if many_lut, we require 1 bit, 1 more bit is required to pack
// the condition. Thus 2 bits of carry are required.
//
// Otherwise we call if_then_else twice, which is less efficient.
if self.carry_modulus().0 < (1 << 2) {
return rayon::join(
|| self.if_then_else_parallelized(condition, b, a),
|| self.if_then_else_parallelized(condition, a, b),
);
}
let (a, b) = rayon::join(
|| self.clean_for_default_op(a),
|| self.clean_for_default_op(b),
);
let zero_out_if_true_fn = |packed| {
let condition = (packed / self.message_modulus().0) & 1;
let value = packed % self.message_modulus().0;
(1 - condition) * value
};
let zero_out_if_false_fn = |packed| {
let condition = (packed / self.message_modulus().0) & 1;
let value = packed % self.message_modulus().0;
condition * value
};
let lut = self
.key
.generate_many_lookup_table(&[&zero_out_if_true_fn, &zero_out_if_false_fn]);
let scaled_condition = self
.key
.unchecked_scalar_mul(&condition.0, self.message_modulus().0 as u8);
let map_condition_lut_on_blocks =
|blocks: &[Ciphertext]| -> (Vec<Ciphertext>, Vec<Ciphertext>) {
let mut left = Vec::with_capacity(blocks.len());
let mut right = Vec::with_capacity(blocks.len());
blocks
.par_iter()
.map(|block| {
let block = self.key.unchecked_add(block, &scaled_condition);
let mut resulting_blocks = self.key.apply_many_lookup_table(&block, &lut);
let second_result = resulting_blocks.pop().unwrap();
let first_result = resulting_blocks.pop().unwrap();
(first_result, second_result)
})
.unzip_into_vecs(&mut left, &mut right);
(left, right)
};
let (
(mut a_blocks_if_cond, mut a_blocks_if_not_cond),
(b_blocks_if_cond, b_blocks_if_not_cond),
) = rayon::join(
|| map_condition_lut_on_blocks(a.blocks()),
|| map_condition_lut_on_blocks(b.blocks()),
);
let clean_lut = self
.key
.generate_lookup_table(|x| x % self.message_modulus().0);
let inplace_add_then_clean_blocks =
|lhs_blocks: &mut [Ciphertext], rhs_blocks: &[Ciphertext]| {
lhs_blocks
.par_iter_mut()
.zip(rhs_blocks.par_iter())
.for_each(|(lhs, rhs)| {
self.key.unchecked_add_assign(lhs, rhs);
self.key.apply_lookup_table_assign(lhs, &clean_lut);
});
};
rayon::join(
|| {
inplace_add_then_clean_blocks(&mut a_blocks_if_cond, &b_blocks_if_not_cond);
},
|| {
inplace_add_then_clean_blocks(&mut a_blocks_if_not_cond, &b_blocks_if_cond);
},
);
(
T::from_blocks(a_blocks_if_cond),
T::from_blocks(a_blocks_if_not_cond),
)
}
}
impl<Scalar> ServerKeyDefaultCMux<&RadixCiphertext, Scalar> for ServerKey
@@ -163,6 +275,15 @@ where
self.unchecked_scalar_if_then_else_parallelized(condition, true_ct_ref, false_value)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
true_ct: &RadixCiphertext,
false_ct: Scalar,
) -> (Self::Output, Self::Output) {
self.scalar_flip_parallelized(condition, true_ct, false_ct)
}
}
impl<Scalar> ServerKeyDefaultCMux<Scalar, &RadixCiphertext> for ServerKey
@@ -217,6 +338,16 @@ where
let inverted_condition = self.boolean_bitnot(condition);
self.if_then_else_parallelized(&inverted_condition, false_ct, true_value)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
true_value: Scalar,
false_ct: &RadixCiphertext,
) -> (Self::Output, Self::Output) {
let inverted_condition = self.boolean_bitnot(condition);
self.flip_parallelized(&inverted_condition, false_ct, true_value)
}
}
impl<Scalar> ServerKeyDefaultCMux<&SignedRadixCiphertext, Scalar> for ServerKey
@@ -243,6 +374,15 @@ where
self.unchecked_scalar_if_then_else_parallelized(condition, true_ct_ref, false_value)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
true_ct: &SignedRadixCiphertext,
false_ct: Scalar,
) -> (Self::Output, Self::Output) {
self.scalar_flip_parallelized(condition, true_ct, false_ct)
}
}
impl<Scalar> ServerKeyDefaultCMux<Scalar, &SignedRadixCiphertext> for ServerKey
@@ -297,6 +437,16 @@ where
let inverted_condition = self.boolean_bitnot(condition);
self.if_then_else_parallelized(&inverted_condition, false_ct, true_value)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
true_value: Scalar,
false_ct: &SignedRadixCiphertext,
) -> (Self::Output, Self::Output) {
let inverted_condition = self.boolean_bitnot(condition);
self.flip_parallelized(&inverted_condition, false_ct, true_value)
}
}
impl ServerKeyDefaultCMux<&BooleanBlock, &BooleanBlock> for ServerKey {
@@ -388,6 +538,67 @@ impl ServerKeyDefaultCMux<&BooleanBlock, &BooleanBlock> for ServerKey {
BooleanBlock::new_unchecked(lhs)
}
fn flip_parallelized(
&self,
condition: &BooleanBlock,
true_ct: &BooleanBlock,
false_ct: &BooleanBlock,
) -> (Self::Output, Self::Output) {
let flip_if_false_fn = |packed| {
let condition = (packed / 2) & 1;
let value = packed % 2;
value * condition
};
let flip_if_true_fn = |packed| {
let condition = (packed / 2) & 1;
let value = packed % 2;
(1 - condition) * value
};
let lut = self
.key
.generate_many_lookup_table(&[&flip_if_false_fn, &flip_if_true_fn]);
let scaled_condition = self.key.unchecked_scalar_mul(&condition.0, 2);
let (vec_a, vec_b) = rayon::join(
|| {
let block = self.key.unchecked_add(&true_ct.0, &scaled_condition);
self.key.apply_many_lookup_table(&block, &lut)
},
|| {
let block = self.key.unchecked_add(&false_ct.0, &scaled_condition);
self.key.apply_many_lookup_table(&block, &lut)
},
);
let [mut a_if_cond, mut a_if_not_cond] = vec_a.try_into().unwrap();
let [b_if_cond, b_if_not_cond] = vec_b.try_into().unwrap();
self.key
.unchecked_add_assign(&mut a_if_cond, &b_if_not_cond);
self.key
.unchecked_add_assign(&mut a_if_not_cond, &b_if_cond);
let clean_lut = self.key.generate_lookup_table(|x| x % 2);
rayon::join(
|| {
self.key
.apply_lookup_table_assign(&mut a_if_cond, &clean_lut)
},
|| {
self.key
.apply_lookup_table_assign(&mut a_if_not_cond, &clean_lut)
},
);
(
BooleanBlock::new_unchecked(a_if_cond),
BooleanBlock::new_unchecked(a_if_not_cond),
)
}
}
impl ServerKey {
@@ -775,4 +986,82 @@ impl ServerKey {
);
});
}
fn scalar_flip_parallelized<T, Scalar>(
&self,
condition: &BooleanBlock,
a: &T,
b: Scalar,
) -> (T, T)
where
Scalar: DecomposableInto<u64>,
T: IntegerRadixCiphertext,
{
let a = self.clean_for_default_op(a);
// To make use of many_lut, we require 1 bit, 1 more bit is required to pack
// the condition. Thus 2 bits of carry are required.
//
// Otherwise we call if_then_else twice, which is less efficient.
if self.carry_modulus().0 < (1 << 2) {
let inverted_condition = self.boolean_bitnot(condition);
return rayon::join(
|| self.unchecked_scalar_if_then_else_parallelized(&inverted_condition, &*a, b),
|| self.unchecked_scalar_if_then_else_parallelized(condition, &*a, b),
);
}
let n_blocks = a.blocks().len();
// One of the input is a clear, so we can embed its decomposed value into the LUTs
// and by using many_lut we can compute both results at once.
let luts = BlockDecomposer::with_block_count(b, self.message_modulus().0.ilog2(), n_blocks)
.iter_as::<u64>()
.map(|scalar_block| {
self.key.generate_many_lookup_table(&[
&|packed| {
let condition = (packed / self.message_modulus().0) & 1;
let value = packed % self.message_modulus().0;
if condition == 1 {
scalar_block
} else {
value
}
},
&|packed| {
let condition = (packed / self.message_modulus().0) & 1;
let value = packed % self.message_modulus().0;
if condition == 1 {
value
} else {
scalar_block
}
},
])
})
.collect::<Vec<_>>();
let scaled_condition = self
.key
.unchecked_scalar_mul(&condition.0, self.message_modulus().0 as u8);
let mut a_blocks = Vec::with_capacity(n_blocks);
let mut b_blocks = Vec::with_capacity(n_blocks);
a.blocks()
.par_iter()
.zip(luts.par_iter())
.map(|(block, lut)| {
let block = self.key.unchecked_add(block, &scaled_condition);
let mut results = self.key.apply_many_lookup_table(&block, lut);
let second = results.pop().unwrap();
let first = results.pop().unwrap();
(first, second)
})
.unzip_into_vecs(&mut a_blocks, &mut b_blocks);
(T::from_blocks(a_blocks), T::from_blocks(b_blocks))
}
}

View File

@@ -38,6 +38,8 @@ pub(crate) mod tests_unsigned;
mod vector_comparisons;
mod vector_find;
use std::borrow::Cow;
use super::ServerKey;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::RadixCiphertext;
@@ -298,4 +300,25 @@ impl ServerKey {
}
})
}
/// Cleans the input ct so that it is ready to be used in a default ops
///
/// Returns a Cow::Owned if a clone was necessary for the cleaning,
/// Cow::Borrowed otherwise
pub(crate) fn clean_for_default_op<'a, T>(&self, ct: &'a T) -> Cow<'a, T>
where
T: IntegerRadixCiphertext,
{
if ct
.blocks()
.iter()
.any(|block| !block.carry_is_empty() || block.noise_level() != NoiseLevel::NOMINAL)
{
let mut cloned = ct.clone();
self.full_propagate_parallelized(&mut cloned);
Cow::Owned(cloned)
} else {
Cow::Borrowed(ct)
}
}
}

View File

@@ -19,6 +19,8 @@ use std::sync::Arc;
create_parameterized_test!(integer_signed_unchecked_if_then_else);
create_parameterized_test!(integer_signed_default_if_then_else);
create_parameterized_test!(integer_signed_default_scalar_if_then_else);
create_parameterized_test!(integer_signed_default_flip);
create_parameterized_test!(integer_signed_default_left_scalar_flip);
fn integer_signed_unchecked_if_then_else<P>(param: P)
where
@@ -52,6 +54,30 @@ where
signed_default_scalar_if_then_else_test(param, executor);
}
fn integer_signed_default_flip<P>(param: P)
where
P: Into<TestParameters>,
{
let func =
|sks: &ServerKey,
cond: &BooleanBlock,
lhs: &SignedRadixCiphertext,
rhs: &SignedRadixCiphertext| { sks.flip_parallelized(cond, lhs, rhs) };
let executor = CpuFunctionExecutor::new(&func);
signed_default_flip_test(param, executor);
}
fn integer_signed_default_left_scalar_flip<P>(param: P)
where
P: Into<TestParameters>,
{
let func = |sks: &ServerKey, cond: &BooleanBlock, lhs: i64, rhs: &SignedRadixCiphertext| {
sks.flip_parallelized(cond, lhs, rhs)
};
let executor = CpuFunctionExecutor::new(&func);
signed_default_left_scalar_flip_test(param, executor);
}
pub(crate) fn signed_default_if_then_else_test<P, T>(param: P, mut executor: T)
where
P: Into<TestParameters>,
@@ -246,3 +272,165 @@ where
assert_eq!(dec_res, if clear_condition { clear_0 } else { clear_1 });
}
}
pub(crate) fn signed_default_flip_test<P, T>(param: P, mut executor: T)
where
P: Into<TestParameters>,
T: for<'a> FunctionExecutor<
(
&'a BooleanBlock,
&'a SignedRadixCiphertext,
&'a SignedRadixCiphertext,
),
(SignedRadixCiphertext, SignedRadixCiphertext),
>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as i64 / 2;
executor.setup(&cks, sks.clone());
fn clear_flip(clear_condition: bool, clear_0: i64, clear_1: i64) -> (i64, i64) {
if clear_condition {
(clear_1, clear_0)
} else {
(clear_0, clear_1)
}
}
for _ in 0..nb_tests {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let clear_condition = rng.gen_bool(0.5);
let mut ctxt_0 = cks.encrypt_signed(clear_0);
let mut ctxt_1 = cks.encrypt_signed(clear_1);
let ctxt_condition = cks.encrypt_bool(clear_condition);
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
assert!(a.block_carries_are_empty());
assert!(b.block_carries_are_empty());
let dec_a: i64 = cks.decrypt_signed(&a);
let dec_b: i64 = cks.decrypt_signed(&b);
let expected = clear_flip(clear_condition, clear_0, clear_1);
assert_eq!(
(dec_a, dec_b),
expected,
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected:?} got ({dec_a}, {dec_b})",
);
let (a2, b2) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
assert_eq!(a, a2, "Operation is not deterministic");
assert_eq!(b, b2, "Operation is not deterministic");
let clear_2 = rng.gen::<i64>() % modulus;
let clear_3 = rng.gen::<i64>() % modulus;
let ctxt_2 = cks.encrypt_signed(clear_2);
let ctxt_3 = cks.encrypt_signed(clear_3);
// Add to have non empty carries
sks.unchecked_add_assign(&mut ctxt_0, &ctxt_2);
sks.unchecked_add_assign(&mut ctxt_1, &ctxt_3);
assert!(!ctxt_0.block_carries_are_empty());
assert!(!ctxt_1.block_carries_are_empty());
let clear_0 = signed_add_under_modulus(clear_0, clear_2, modulus);
let clear_1 = signed_add_under_modulus(clear_1, clear_3, modulus);
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
assert!(a.block_carries_are_empty());
assert!(b.block_carries_are_empty());
let dec_a: i64 = cks.decrypt_signed(&a);
let dec_b: i64 = cks.decrypt_signed(&b);
let expected = clear_flip(clear_condition, clear_0, clear_1);
assert_eq!(
(dec_a, dec_b),
expected,
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected:?} got ({dec_a}, {dec_b})",
);
}
}
pub(crate) fn signed_default_left_scalar_flip_test<P, T>(param: P, mut executor: T)
where
P: Into<TestParameters>,
T: for<'a> FunctionExecutor<
(&'a BooleanBlock, i64, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, SignedRadixCiphertext),
>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as i64 / 2;
executor.setup(&cks, sks);
fn clear_flip(clear_condition: bool, clear_0: i64, clear_1: i64) -> (i64, i64) {
if clear_condition {
(clear_1, clear_0)
} else {
(clear_0, clear_1)
}
}
for _ in 0..nb_tests {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let clear_condition = rng.gen_bool(0.5);
let ctxt_condition = cks.encrypt_bool(clear_condition);
let ctxt_rhs = cks.encrypt_signed(clear_1);
let (a, b) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
assert_eq!(a.blocks.len(), NB_CTXT);
assert_eq!(b.blocks.len(), NB_CTXT);
assert!(a.block_carries_are_empty());
assert!(b.block_carries_are_empty());
assert!(a
.blocks
.iter()
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
assert!(b
.blocks
.iter()
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
let dec_a: i64 = cks.decrypt_signed(&a);
let dec_b: i64 = cks.decrypt_signed(&b);
let expected = clear_flip(clear_condition, clear_0, clear_1);
assert_eq!(
(dec_a, dec_b),
expected,
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected:?} got ({dec_a}, {dec_b})",
);
let (a2, b2) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
assert_eq!(a, a2, "Operation is not deterministic");
assert_eq!(b, b2, "Operation is not deterministic");
}
}

View File

@@ -17,6 +17,8 @@ create_parameterized_test!(integer_unchecked_left_scalar_if_then_else);
create_parameterized_test!(integer_smart_if_then_else);
create_parameterized_test!(integer_default_if_then_else);
create_parameterized_test!(integer_default_scalar_if_then_else);
create_parameterized_test!(integer_default_flip);
create_parameterized_test!(integer_default_left_scalar_flip);
fn integer_unchecked_left_scalar_if_then_else<P>(param: P)
where
@@ -65,6 +67,29 @@ where
default_scalar_if_then_else_test(param, executor);
}
fn integer_default_flip<P>(param: P)
where
P: Into<TestParameters>,
{
let func = |sks: &ServerKey,
cond: &BooleanBlock,
lhs: &RadixCiphertext,
rhs: &RadixCiphertext| { sks.flip_parallelized(cond, lhs, rhs) };
let executor = CpuFunctionExecutor::new(&func);
default_flip_test(param, executor);
}
fn integer_default_left_scalar_flip<P>(param: P)
where
P: Into<TestParameters>,
{
let func = |sks: &ServerKey, cond: &BooleanBlock, lhs: u64, rhs: &RadixCiphertext| {
sks.flip_parallelized(cond, lhs, rhs)
};
let executor = CpuFunctionExecutor::new(&func);
default_left_scalar_flip_test(param, executor);
}
pub(crate) fn smart_if_then_else_test<P, T>(param: P, mut executor: T)
where
P: Into<TestParameters>,
@@ -331,3 +356,161 @@ where
);
}
}
pub(crate) fn default_flip_test<P, T>(param: P, mut executor: T)
where
P: Into<TestParameters>,
T: for<'a> FunctionExecutor<
(&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext),
(RadixCiphertext, RadixCiphertext),
>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
executor.setup(&cks, sks.clone());
fn clear_flip(clear_condition: bool, clear_0: u64, clear_1: u64) -> (u64, u64) {
if clear_condition {
(clear_1, clear_0)
} else {
(clear_0, clear_1)
}
}
for _ in 0..nb_tests {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
let clear_condition = rng.gen_bool(0.5);
let mut ctxt_0 = cks.encrypt(clear_0);
let mut ctxt_1 = cks.encrypt(clear_1);
let ctxt_condition = cks.encrypt_bool(clear_condition);
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
assert!(a.block_carries_are_empty());
assert!(b.block_carries_are_empty());
let dec_a: u64 = cks.decrypt(&a);
let dec_b: u64 = cks.decrypt(&b);
let expected = clear_flip(clear_condition, clear_0, clear_1);
assert_eq!(
(dec_a, dec_b),
expected,
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected:?} got ({dec_a}, {dec_b})",
);
let (a2, b2) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
assert_eq!(a, a2, "Operation is not deterministic");
assert_eq!(b, b2, "Operation is not deterministic");
let clear_2 = rng.gen::<u64>() % modulus;
let clear_3 = rng.gen::<u64>() % modulus;
let ctxt_2 = cks.encrypt(clear_2);
let ctxt_3 = cks.encrypt(clear_3);
// Add to have non empty carries
sks.unchecked_add_assign(&mut ctxt_0, &ctxt_2);
sks.unchecked_add_assign(&mut ctxt_1, &ctxt_3);
assert!(!ctxt_0.block_carries_are_empty());
assert!(!ctxt_1.block_carries_are_empty());
let clear_0 = (clear_0 + clear_2) % modulus;
let clear_1 = (clear_1 + clear_3) % modulus;
let (a, b) = executor.execute((&ctxt_condition, &ctxt_0, &ctxt_1));
assert!(a.block_carries_are_empty());
assert!(b.block_carries_are_empty());
let dec_a: u64 = cks.decrypt(&a);
let dec_b: u64 = cks.decrypt(&b);
let expected = clear_flip(clear_condition, clear_0, clear_1);
assert_eq!(
(dec_a, dec_b),
expected,
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected:?} got ({dec_a}, {dec_b})",
);
}
}
pub(crate) fn default_left_scalar_flip_test<P, T>(param: P, mut executor: T)
where
P: Into<TestParameters>,
T: for<'a> FunctionExecutor<
(&'a BooleanBlock, u64, &'a RadixCiphertext),
(RadixCiphertext, RadixCiphertext),
>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
executor.setup(&cks, sks);
fn clear_flip(clear_condition: bool, clear_0: u64, clear_1: u64) -> (u64, u64) {
if clear_condition {
(clear_1, clear_0)
} else {
(clear_0, clear_1)
}
}
for _ in 0..nb_tests {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
let clear_condition = rng.gen_bool(0.5);
let ctxt_condition = cks.encrypt_bool(clear_condition);
let ctxt_rhs = cks.encrypt(clear_1);
let (a, b) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
assert_eq!(a.blocks.len(), NB_CTXT);
assert_eq!(b.blocks.len(), NB_CTXT);
assert!(a.block_carries_are_empty());
assert!(b.block_carries_are_empty());
assert!(a
.blocks
.iter()
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
assert!(b
.blocks
.iter()
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
let dec_a: u64 = cks.decrypt(&a);
let dec_b: u64 = cks.decrypt(&b);
let expected = clear_flip(clear_condition, clear_0, clear_1);
assert_eq!(
(dec_a, dec_b),
expected,
"Invalid result for flip({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected:?} got ({dec_a}, {dec_b})",
);
let (a2, b2) = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
assert_eq!(a, a2, "Operation is not deterministic");
assert_eq!(b, b2, "Operation is not deterministic");
}
}