mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
refactor(shortint): replace clear_carry by message_extract
This commit is contained in:
@@ -149,7 +149,7 @@ impl KreyviumStreamShortint {
|
||||
.unchecked_add_assign(&mut new_c, c5);
|
||||
self.internal_server_key
|
||||
.unchecked_add_assign(&mut new_c, &temp_b);
|
||||
self.internal_server_key.clear_carry_assign(&mut new_c);
|
||||
self.internal_server_key.message_extract_assign(&mut new_c);
|
||||
new_c
|
||||
},
|
||||
|| {
|
||||
|
||||
@@ -113,7 +113,7 @@ impl TriviumStreamShortint {
|
||||
.unchecked_add_assign(&mut new_a, a5);
|
||||
self.internal_server_key
|
||||
.unchecked_add_assign(&mut new_a, &temp_c);
|
||||
self.internal_server_key.clear_carry_assign(&mut new_a);
|
||||
self.internal_server_key.message_extract_assign(&mut new_a);
|
||||
new_a
|
||||
},
|
||||
|| {
|
||||
@@ -122,7 +122,7 @@ impl TriviumStreamShortint {
|
||||
.unchecked_add_assign(&mut new_b, b5);
|
||||
self.internal_server_key
|
||||
.unchecked_add_assign(&mut new_b, &temp_a);
|
||||
self.internal_server_key.clear_carry_assign(&mut new_b);
|
||||
self.internal_server_key.message_extract_assign(&mut new_b);
|
||||
new_b
|
||||
},
|
||||
)
|
||||
@@ -135,7 +135,7 @@ impl TriviumStreamShortint {
|
||||
.unchecked_add_assign(&mut new_c, c5);
|
||||
self.internal_server_key
|
||||
.unchecked_add_assign(&mut new_c, &temp_b);
|
||||
self.internal_server_key.clear_carry_assign(&mut new_c);
|
||||
self.internal_server_key.message_extract_assign(&mut new_c);
|
||||
new_c
|
||||
},
|
||||
|| {
|
||||
|
||||
@@ -217,7 +217,7 @@ impl ServerKey {
|
||||
(true, true) => (ct1, ct2),
|
||||
(true, false) => {
|
||||
tmp_rhs = ct2.clone();
|
||||
self.key.clear_carry_assign(&mut tmp_rhs);
|
||||
self.key.message_extract_assign(&mut tmp_rhs);
|
||||
(ct1, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
@@ -228,7 +228,7 @@ impl ServerKey {
|
||||
tmp_rhs = ct2.clone();
|
||||
rayon::join(
|
||||
|| self.full_propagate_parallelized(ct1),
|
||||
|| self.key.clear_carry_assign(&mut tmp_rhs),
|
||||
|| self.key.message_extract_assign(&mut tmp_rhs),
|
||||
);
|
||||
(ct1, &tmp_rhs)
|
||||
}
|
||||
|
||||
@@ -397,32 +397,6 @@ impl ShortintEngine {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn clear_carry(
|
||||
&mut self,
|
||||
server_key: &ServerKey,
|
||||
ct: &Ciphertext,
|
||||
) -> EngineResult<Ciphertext> {
|
||||
let mut ct_in = ct.clone();
|
||||
self.clear_carry_assign(server_key, &mut ct_in)?;
|
||||
Ok(ct_in)
|
||||
}
|
||||
|
||||
pub(crate) fn clear_carry_assign(
|
||||
&mut self,
|
||||
server_key: &ServerKey,
|
||||
ct: &mut Ciphertext,
|
||||
) -> EngineResult<()> {
|
||||
match server_key.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => {
|
||||
self.keyswitch_bootstrap_assign(server_key, ct)?;
|
||||
}
|
||||
PBSOrder::BootstrapKeyswitch => {
|
||||
self.bootstrap_keyswitch_assign(server_key, ct)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_programmable_bootstrap_assign(
|
||||
&mut self,
|
||||
server_key: &ServerKey,
|
||||
|
||||
@@ -118,13 +118,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
|
||||
@@ -114,13 +114,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -538,13 +538,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -966,13 +966,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
|
||||
@@ -65,14 +65,14 @@ impl ServerKey {
|
||||
let lhs = if ct_left.carry_is_empty() {
|
||||
ct_left
|
||||
} else {
|
||||
tmp_lhs = self.clear_carry(ct_left);
|
||||
tmp_lhs = self.message_extract(ct_left);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -279,14 +279,14 @@ impl ServerKey {
|
||||
let lhs = if ct_left.carry_is_empty() {
|
||||
ct_left
|
||||
} else {
|
||||
tmp_lhs = self.clear_carry(ct_left);
|
||||
tmp_lhs = self.message_extract(ct_left);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -505,14 +505,14 @@ impl ServerKey {
|
||||
let lhs = if ct_left.carry_is_empty() {
|
||||
ct_left
|
||||
} else {
|
||||
tmp_lhs = self.clear_carry(ct_left);
|
||||
tmp_lhs = self.message_extract(ct_left);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -721,14 +721,14 @@ impl ServerKey {
|
||||
let lhs = if ct_left.carry_is_empty() {
|
||||
ct_left
|
||||
} else {
|
||||
tmp_lhs = self.clear_carry(ct_left);
|
||||
tmp_lhs = self.message_extract(ct_left);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -945,14 +945,14 @@ impl ServerKey {
|
||||
let lhs = if ct_left.carry_is_empty() {
|
||||
ct_left
|
||||
} else {
|
||||
tmp_lhs = self.clear_carry(ct_left);
|
||||
tmp_lhs = self.message_extract(ct_left);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -1159,14 +1159,14 @@ impl ServerKey {
|
||||
let lhs = if ct_left.carry_is_empty() {
|
||||
ct_left
|
||||
} else {
|
||||
tmp_lhs = self.clear_carry(ct_left);
|
||||
tmp_lhs = self.message_extract(ct_left);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
|
||||
@@ -123,13 +123,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
|
||||
@@ -413,60 +413,6 @@ impl ServerKey {
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute a keyswitch and a bootstrap, returning a new ciphertext with empty
|
||||
/// carry bits.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::shortint::gen_keys;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(3);
|
||||
/// // | ct1 |
|
||||
/// // | carry | message |
|
||||
/// // |-------|---------|
|
||||
/// // | 0 0 | 1 1 |
|
||||
/// let mut ct2 = cks.encrypt(2);
|
||||
/// // | ct2 |
|
||||
/// // | carry | message |
|
||||
/// // |-------|---------|
|
||||
/// // | 0 0 | 1 0 |
|
||||
///
|
||||
/// let ct_res = sks.smart_add(&mut ct1, &mut ct2);
|
||||
/// // | ct_res |
|
||||
/// // | carry | message |
|
||||
/// // |-------|---------|
|
||||
/// // | 0 1 | 0 1 |
|
||||
///
|
||||
/// // Get the carry
|
||||
/// let ct_carry = sks.carry_extract(&ct_res);
|
||||
/// let carry = cks.decrypt(&ct_carry);
|
||||
/// assert_eq!(carry, 1);
|
||||
///
|
||||
/// let ct_res = sks.clear_carry(&ct_res);
|
||||
///
|
||||
/// let ct_carry = sks.carry_extract(&ct_res);
|
||||
/// let carry = cks.decrypt(&ct_carry);
|
||||
/// assert_eq!(carry, 0);
|
||||
///
|
||||
/// let clear = cks.decrypt(&ct_res);
|
||||
///
|
||||
/// assert_eq!(clear, (3 + 2) % 4);
|
||||
/// ```
|
||||
pub fn clear_carry(&self, ct_in: &Ciphertext) -> Ciphertext {
|
||||
ShortintEngine::with_thread_local_mut(|engine| engine.clear_carry(self, ct_in).unwrap())
|
||||
}
|
||||
|
||||
pub fn clear_carry_assign(&self, ct_in: &mut Ciphertext) {
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
engine.clear_carry_assign(self, ct_in).unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute a keyswitch and programmable bootstrap.
|
||||
///
|
||||
/// # Example
|
||||
|
||||
@@ -857,13 +857,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
@@ -873,7 +873,7 @@ impl ServerKey {
|
||||
.unchecked_mul_lsb_small_carry_modulus_assign(self, ct_left, rhs)
|
||||
.unwrap()
|
||||
});
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
} else {
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
engine.unchecked_mul_lsb_assign(self, ct_left, rhs).unwrap()
|
||||
@@ -996,13 +996,13 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
|
||||
@@ -113,10 +113,10 @@ impl ServerKey {
|
||||
/// ```
|
||||
pub fn neg_assign(&self, ct: &mut Ciphertext) {
|
||||
if !ct.carry_is_empty() {
|
||||
self.clear_carry_assign(ct);
|
||||
self.message_extract_assign(ct);
|
||||
}
|
||||
self.unchecked_neg_assign(ct);
|
||||
self.clear_carry_assign(ct);
|
||||
self.message_extract_assign(ct);
|
||||
}
|
||||
|
||||
/// Homomorphically negates a message without checks.
|
||||
|
||||
@@ -36,7 +36,7 @@ impl ServerKey {
|
||||
|
||||
pub fn scalar_bitand_assign(&self, lhs: &mut Ciphertext, rhs: u8) {
|
||||
if !lhs.carry_is_empty() {
|
||||
self.clear_carry_assign(lhs);
|
||||
self.message_extract_assign(lhs);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_bitand_assign(lhs, rhs);
|
||||
@@ -103,7 +103,7 @@ impl ServerKey {
|
||||
|
||||
pub fn scalar_bitxor_assign(&self, lhs: &mut Ciphertext, rhs: u8) {
|
||||
if !lhs.carry_is_empty() {
|
||||
self.clear_carry_assign(lhs);
|
||||
self.message_extract_assign(lhs);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_bitxor_assign(lhs, rhs);
|
||||
@@ -169,7 +169,7 @@ impl ServerKey {
|
||||
|
||||
pub fn scalar_bitor_assign(&self, lhs: &mut Ciphertext, rhs: u8) {
|
||||
if !lhs.carry_is_empty() {
|
||||
self.clear_carry_assign(lhs);
|
||||
self.message_extract_assign(lhs);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_bitor_assign(lhs, rhs);
|
||||
|
||||
@@ -77,18 +77,18 @@ impl ServerKey {
|
||||
let tmp_rhs: Ciphertext;
|
||||
|
||||
if !ct_left.carry_is_empty() {
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
let rhs = if ct_right.carry_is_empty() {
|
||||
ct_right
|
||||
} else {
|
||||
tmp_rhs = self.clear_carry(ct_right);
|
||||
tmp_rhs = self.message_extract(ct_right);
|
||||
&tmp_rhs
|
||||
};
|
||||
|
||||
self.unchecked_sub_assign(ct_left, rhs);
|
||||
self.clear_carry_assign(ct_left);
|
||||
self.message_extract_assign(ct_left);
|
||||
}
|
||||
|
||||
/// Homomorphically subtracts ct_right to ct_left.
|
||||
|
||||
@@ -310,7 +310,7 @@ where
|
||||
let ctxt_0 = cks.encrypt(clear_0);
|
||||
|
||||
// keyswitch and bootstrap
|
||||
let ct_res = sks.clear_carry(&ctxt_0);
|
||||
let ct_res = sks.message_extract(&ctxt_0);
|
||||
|
||||
// decryption of ct_res
|
||||
let dec_res = cks.decrypt(&ct_res);
|
||||
|
||||
Reference in New Issue
Block a user