mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
fix(integer): scalar cmux
Default shortint ops were used, leading to carries being cleaned internally which we did not want
This commit is contained in:
@@ -436,7 +436,7 @@ impl ServerKey {
|
||||
let block = block_condition / 2;
|
||||
let condition = block_condition % 2;
|
||||
if condition == 1 {
|
||||
block
|
||||
block % self.message_modulus().0
|
||||
} else {
|
||||
scalar_block
|
||||
}
|
||||
@@ -449,8 +449,9 @@ impl ServerKey {
|
||||
.par_iter()
|
||||
.zip(luts.par_iter())
|
||||
.map(|(block, lut)| {
|
||||
let mut result_block = self.key.scalar_mul(block, 2);
|
||||
self.key.add_assign(&mut result_block, &condition.0);
|
||||
let mut result_block = self.key.unchecked_scalar_mul(block, 2);
|
||||
self.key
|
||||
.unchecked_add_assign(&mut result_block, &condition.0);
|
||||
self.key.apply_lookup_table_assign(&mut result_block, lut);
|
||||
result_block
|
||||
})
|
||||
|
||||
@@ -13,10 +13,27 @@ use crate::shortint::parameters::*;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
create_parameterized_test!(integer_unchecked_left_scalar_if_then_else);
|
||||
create_parameterized_test!(integer_smart_if_then_else);
|
||||
create_parameterized_test!(integer_default_if_then_else);
|
||||
create_parameterized_test!(integer_default_scalar_if_then_else);
|
||||
|
||||
fn integer_unchecked_left_scalar_if_then_else<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
fn func(
|
||||
sks: &ServerKey,
|
||||
cond: &BooleanBlock,
|
||||
lhs: u64,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
sks.if_then_else_parallelized(cond, lhs, rhs)
|
||||
}
|
||||
let executor = CpuFunctionExecutor::new(&func);
|
||||
unchecked_left_scalar_if_then_else_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_smart_if_then_else<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
@@ -269,3 +286,48 @@ where
|
||||
assert_eq!(ct_res, ct_res2, "Operation is not deterministic");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unchecked_left_scalar_if_then_else_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(&'a BooleanBlock, u64, &'a RadixCiphertext), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32);
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
let clear_condition = rng.gen_bool(0.5);
|
||||
|
||||
let ctxt_condition = cks.encrypt_bool(clear_condition);
|
||||
let ctxt_rhs = cks.encrypt(clear_1);
|
||||
|
||||
let ct_res = executor.execute((&ctxt_condition, clear_0, &ctxt_rhs));
|
||||
assert_eq!(ct_res.blocks.len(), NB_CTXT);
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert!(ct_res
|
||||
.blocks
|
||||
.iter()
|
||||
.all(|b| b.noise_level() == NoiseLevel::NOMINAL));
|
||||
|
||||
let dec_res: u64 = cks.decrypt(&ct_res);
|
||||
let expected_result = if clear_condition { clear_0 } else { clear_1 };
|
||||
assert_eq!(
|
||||
dec_res, expected_result,
|
||||
"Invalid result for cmux({clear_condition}, {clear_0}, {clear_1})\n\
|
||||
Expected {expected_result} got {dec_res}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user