diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index 97160259a..3fbf7d287 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -33,6 +33,15 @@ impl From for crate::shortint::ServerKey { } } +/// Compute the [`MaxDegree`] for an integer server key (compressed or uncompressed). This formula +/// provisions a free carry bit. This allows carry propagation between shortint blocks in a +/// [`RadixCiphertext`](`crate::integer::RadixCiphertext`), as that process requires adding a bit of +/// carry from one shortint block to the next, which would overflow and lead to wrong results if we +/// did not provision that carry bit. +fn integer_server_key_max_degree(parameters: crate::shortint::ShortintParameterSet) -> MaxDegree { + MaxDegree((parameters.message_modulus().0 - 1) * parameters.carry_modulus().0 - 1) +} + impl ServerKey { /// Generates a server key. /// @@ -54,13 +63,11 @@ impl ServerKey { { // It should remain just enough space to add a carry let client_key = cks.as_ref(); - let max = (client_key.key.parameters.message_modulus().0 - 1) - * client_key.key.parameters.carry_modulus().0 - - 1; + let max_degree = integer_server_key_max_degree(client_key.key.parameters); let sks = crate::shortint::server_key::ServerKey::new_with_max_degree( &client_key.key, - MaxDegree(max), + max_degree, ); ServerKey { key: sks } @@ -87,10 +94,9 @@ impl ServerKey { mut key: crate::shortint::server_key::ServerKey, ) -> ServerKey { // It should remain just enough space add a carry - let max = - (cks.key.parameters.message_modulus().0 - 1) * cks.key.parameters.carry_modulus().0 - 1; + let max_degree = integer_server_key_max_degree(cks.key.parameters); - key.max_degree = MaxDegree(max); + key.max_degree = max_degree; ServerKey { key } } @@ -111,7 +117,10 @@ pub struct CompressedServerKey { impl CompressedServerKey { pub fn new(client_key: &ClientKey) -> CompressedServerKey { - let key = crate::shortint::CompressedServerKey::new(&client_key.key); + let max_degree = integer_server_key_max_degree(client_key.key.parameters); + + let key = + crate::shortint::CompressedServerKey::new_with_max_degree(&client_key.key, max_degree); Self { key } } } @@ -122,3 +131,46 @@ impl From for ServerKey { Self { key } } } + +#[cfg(test)] +mod test { + use super::*; + use crate::integer::RadixClientKey; + use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + + /// https://github.com/zama-ai/tfhe-rs/issues/460 + /// Problem with CompressedServerKey degree being set to shortint MaxDegree not accounting for + /// the necessary carry bits for e.g. Radix carry propagation. + #[test] + fn test_compressed_server_key_max_degree() { + let cks = ClientKey::new(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS); + // msg_mod = 4, carry_mod = 4, (msg_mod - 1) * carry_mod = 12; minus 1 => 11 + let expected_max_degree = MaxDegree(11); + + let sks = ServerKey::new(&cks); + assert_eq!(sks.key.max_degree, expected_max_degree); + + let csks = CompressedServerKey::new(&cks); + assert_eq!(csks.key.max_degree, expected_max_degree); + + let decompressed_sks: ServerKey = csks.into(); + assert_eq!(decompressed_sks.key.max_degree, expected_max_degree); + + // Repro case from the user + { + let client_key = RadixClientKey::new(PARAM_MESSAGE_2_CARRY_2, 14); + let compressed_eval_key = CompressedServerKey::new(client_key.as_ref()); + let evaluation_key = ServerKey::from(compressed_eval_key); + let modulus = (client_key.parameters().message_modulus().0 as u128) + .pow(client_key.num_blocks() as u32); + + let mut ct = client_key.encrypt(modulus - 1); + let mut res_ct = ct.clone(); + for _ in 0..5 { + res_ct = evaluation_key.smart_add_parallelized(&mut res_ct, &mut ct); + } + let res = client_key.decrypt::(&res_ct); + assert_eq!(modulus - 6, res); + } + } +} diff --git a/tfhe/src/shortint/server_key/compressed.rs b/tfhe/src/shortint/server_key/compressed.rs index 5709fca3b..097e5df18 100644 --- a/tfhe/src/shortint/server_key/compressed.rs +++ b/tfhe/src/shortint/server_key/compressed.rs @@ -54,4 +54,13 @@ impl CompressedServerKey { engine.new_compressed_server_key(client_key).unwrap() }) } + + /// Generate a compressed server key with a chosen maximum degree + pub fn new_with_max_degree(cks: &ClientKey, max_degree: MaxDegree) -> CompressedServerKey { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .new_compressed_server_key_with_max_degree(cks, max_degree) + .unwrap() + }) + } }