fix(integer): scalar cmux

Default shortint ops were used, leading
to carries being cleaned internally which
we did not want
This commit is contained in:
tmontaigu
2025-03-26 11:48:25 +01:00
parent 78fc99aa79
commit 1e5fb00715
2 changed files with 66 additions and 3 deletions

View File

@@ -436,7 +436,7 @@ impl ServerKey {
let block = block_condition / 2;
let condition = block_condition % 2;
if condition == 1 {
block
block % self.message_modulus().0
} else {
scalar_block
}
@@ -449,8 +449,9 @@ impl ServerKey {
.par_iter()
.zip(luts.par_iter())
.map(|(block, lut)| {
let mut result_block = self.key.scalar_mul(block, 2);
self.key.add_assign(&mut result_block, &condition.0);
let mut result_block = self.key.unchecked_scalar_mul(block, 2);
self.key
.unchecked_add_assign(&mut result_block, &condition.0);
self.key.apply_lookup_table_assign(&mut result_block, lut);
result_block
})

View File

@@ -13,10 +13,27 @@ use crate::shortint::parameters::*;
use rand::Rng;
use std::sync::Arc;
create_parameterized_test!(integer_unchecked_left_scalar_if_then_else);
create_parameterized_test!(integer_smart_if_then_else);
create_parameterized_test!(integer_default_if_then_else);
create_parameterized_test!(integer_default_scalar_if_then_else);
fn integer_unchecked_left_scalar_if_then_else<P>(param: P)
where
P: Into<PBSParameters>,
{
fn func(
sks: &ServerKey,
cond: &BooleanBlock,
lhs: u64,
rhs: &RadixCiphertext,
) -> RadixCiphertext {
sks.if_then_else_parallelized(cond, lhs, rhs)
}
let executor = CpuFunctionExecutor::new(&func);
unchecked_left_scalar_if_then_else_test(param, executor);
}
fn integer_smart_if_then_else<P>(param: P)
where
P: Into<PBSParameters>,
@@ -269,3 +286,48 @@ where
assert_eq!(ct_res, ct_res2, "Operation is not deterministic");
}
}
pub(crate) fn unchecked_left_scalar_if_then_else_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<(&'a BooleanBlock, u64, &'a RadixCiphertext), RadixCiphertext>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
executor.setup(&cks, sks);
for _ in 0..nb_tests {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
let clear_condition = rng.gen_bool(0.5);
let ctxt_condition = cks.encrypt_bool(clear_condition);
let ctxt_rhs = cks.encrypt(clear_1);
let ct_res = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
assert_eq!(ct_res.blocks.len(), NB_CTXT);
assert!(ct_res.block_carries_are_empty());
assert!(ct_res
.blocks
.iter()
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
let dec_res: u64 = cks.decrypt(&ct_res);
let expected_result = if clear_condition { clear_0 } else { clear_1 };
assert_eq!(
dec_res, expected_result,
"Invalid result for cmux({clear_condition}, {clear_0}, {clear_1})\n\
Expected {expected_result} got {dec_res}"
);
}
}