diff --git a/tfhe/src/integer/client_key/mod.rs b/tfhe/src/integer/client_key/mod.rs index 7effca02f..945fbfd97 100644 --- a/tfhe/src/integer/client_key/mod.rs +++ b/tfhe/src/integer/client_key/mod.rs @@ -11,6 +11,7 @@ use crate::integer::ciphertext::{ CompressedCrtCiphertext, CompressedRadixCiphertext, CrtCiphertext, RadixCiphertext, }; use crate::integer::client_key::utils::i_crt; +use crate::integer::encryption::{encrypt_crt, encrypt_words_radix, ClearText}; use crate::shortint::parameters::MessageModulus; use crate::shortint::{ Ciphertext as ShortintCiphertext, ClientKey as ShortintClientKey, @@ -19,7 +20,6 @@ use crate::shortint::{ use serde::{Deserialize, Serialize}; pub use utils::radix_decomposition; -use crate::integer::U256; pub use crt::CrtClientKey; pub use radix::RadixClientKey; @@ -53,46 +53,6 @@ impl AsRef for ClientKey { } } -pub trait ClearText { - fn as_words(&self) -> &[u64]; - - fn as_words_mut(&mut self) -> &mut [u64]; -} - -impl ClearText for u64 { - fn as_words(&self) -> &[u64] { - std::slice::from_ref(self) - } - - fn as_words_mut(&mut self) -> &mut [u64] { - std::slice::from_mut(self) - } -} - -impl ClearText for u128 { - fn as_words(&self) -> &[u64] { - let u128_slc = std::slice::from_ref(self); - unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 2) } - } - - fn as_words_mut(&mut self) -> &mut [u64] { - let u128_slc = std::slice::from_mut(self); - unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 2) } - } -} - -impl ClearText for U256 { - fn as_words(&self) -> &[u64] { - let u128_slc = self.0.as_slice(); - unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 4) } - } - - fn as_words_mut(&mut self) -> &mut [u64] { - let u128_slc = self.0.as_mut_slice(); - unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 4) } - } -} - impl ClientKey { /// Creates a Client Key. /// @@ -217,40 +177,7 @@ impl ClientKey { F: Fn(&crate::shortint::ClientKey, u64) -> Block, RadixCiphertextType: From>, { - let mask = (self.key.parameters.message_modulus.0 - 1) as u128; - let block_modulus = self.key.parameters.message_modulus.0 as u128; - - let mut blocks = Vec::with_capacity(num_blocks); - let mut message_block_iter = message_words.iter().copied(); - - let mut source = 0u128; // stores the bits of the word to be encrypted in one of the iteration - let mut valid_until_power = 1; // 2^0 = 1, start with nothing valid - let mut current_power = 1; // where the next bits to encrypt starts - for _ in 0..num_blocks { - // Are we going to encrypt bits that are not valid ? - // If so, discard already encrypted bits and fetch bits form the input words - if (current_power * block_modulus) >= valid_until_power { - source /= current_power; - valid_until_power /= current_power; - - source += message_block_iter - .next() - .map(u128::from) - .unwrap_or_default() - * valid_until_power; - - current_power = 1; - valid_until_power <<= 64; - } - - let block_value = (source & (mask * current_power)) / current_power; - let ct = encrypt_block(&self.key, block_value as u64); - blocks.push(ct); - - current_power *= block_modulus; - } - - RadixCiphertextType::from(blocks) + encrypt_words_radix(&self.key, message_words, num_blocks, encrypt_block) } /// Encrypts one block. @@ -544,7 +471,7 @@ impl ClientKey { result % whole_modulus } - pub fn encrypt_crt_impl( + fn encrypt_crt_impl( &self, message: u64, base_vec: Vec, @@ -554,16 +481,6 @@ impl ClientKey { F: Fn(&crate::shortint::ClientKey, u64, MessageModulus) -> Block, CrtCiphertextType: From<(Vec, Vec)>, { - let mut ctxt_vect = Vec::with_capacity(base_vec.len()); - - // Put each decomposition into a new ciphertext - for modulus in base_vec.iter().copied() { - // encryption - let ct = encrypt_block(&self.key, message, MessageModulus(modulus as usize)); - - ctxt_vect.push(ct); - } - - CrtCiphertextType::from((ctxt_vect, base_vec)) + encrypt_crt(&self.key, message, base_vec, encrypt_block) } } diff --git a/tfhe/src/integer/encryption.rs b/tfhe/src/integer/encryption.rs new file mode 100644 index 000000000..9336e8066 --- /dev/null +++ b/tfhe/src/integer/encryption.rs @@ -0,0 +1,134 @@ +use super::U256; +use crate::shortint::parameters::MessageModulus; + +pub trait ClearText { + fn as_words(&self) -> &[u64]; + + fn as_words_mut(&mut self) -> &mut [u64]; +} + +impl ClearText for u64 { + fn as_words(&self) -> &[u64] { + std::slice::from_ref(self) + } + + fn as_words_mut(&mut self) -> &mut [u64] { + std::slice::from_mut(self) + } +} + +impl ClearText for u128 { + fn as_words(&self) -> &[u64] { + let u128_slc = std::slice::from_ref(self); + unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 2) } + } + + fn as_words_mut(&mut self) -> &mut [u64] { + let u128_slc = std::slice::from_mut(self); + unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 2) } + } +} + +impl ClearText for U256 { + fn as_words(&self) -> &[u64] { + let u128_slc = self.0.as_slice(); + unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 4) } + } + + fn as_words_mut(&mut self) -> &mut [u64] { + let u128_slc = self.0.as_mut_slice(); + unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 4) } + } +} + +pub(crate) trait BlockEncryptionKey { + fn parameters(&self) -> &crate::shortint::Parameters; +} + +impl BlockEncryptionKey for crate::shortint::ClientKey { + fn parameters(&self) -> &crate::shortint::Parameters { + &self.parameters + } +} + +impl BlockEncryptionKey for crate::shortint::PublicKey { + fn parameters(&self) -> &crate::shortint::Parameters { + &self.parameters + } +} + +impl BlockEncryptionKey for crate::shortint::CompressedPublicKey { + fn parameters(&self) -> &crate::shortint::Parameters { + &self.parameters + } +} + +pub(crate) fn encrypt_words_radix( + encrypting_key: &BlockKey, + message_words: &[u64], + num_blocks: usize, + encrypt_block: F, +) -> RadixCiphertextType +where + BlockKey: BlockEncryptionKey, + F: Fn(&BlockKey, u64) -> Block, + RadixCiphertextType: From>, +{ + let mask = (encrypting_key.parameters().message_modulus.0 - 1) as u128; + let block_modulus = encrypting_key.parameters().message_modulus.0 as u128; + + let mut blocks = Vec::with_capacity(num_blocks); + let mut message_block_iter = message_words.iter().copied(); + + let mut source = 0u128; // stores the bits of the word to be encrypted in one of the iteration + let mut valid_until_power = 1; // 2^0 = 1, start with nothing valid + let mut current_power = 1; // where the next bits to encrypt starts + for _ in 0..num_blocks { + // Are we going to encrypt bits that are not valid ? + // If so, discard already encrypted bits and fetch bits form the input words + if (current_power * block_modulus) >= valid_until_power { + source /= current_power; + valid_until_power /= current_power; + + source += message_block_iter + .next() + .map(u128::from) + .unwrap_or_default() + * valid_until_power; + + current_power = 1; + valid_until_power <<= 64; + } + + let block_value = (source & (mask * current_power)) / current_power; + let ct = encrypt_block(encrypting_key, block_value as u64); + blocks.push(ct); + + current_power *= block_modulus; + } + + RadixCiphertextType::from(blocks) +} + +pub(crate) fn encrypt_crt( + encrypting_key: &BlockKey, + message: u64, + base_vec: Vec, + encrypt_block: F, +) -> CrtCiphertextType +where + F: Fn(&BlockKey, u64, MessageModulus) -> Block, + CrtCiphertextType: From<(Vec, Vec)>, +{ + let mut ctxt_vect = Vec::with_capacity(base_vec.len()); + + // Put each decomposition into a new ciphertext + for modulus in base_vec.iter().copied() { + // encryption + let ct = encrypt_block(encrypting_key, message, MessageModulus(modulus as usize)); + + ctxt_vect.push(ct); + } + + CrtCiphertextType::from((ctxt_vect, base_vec)) +} diff --git a/tfhe/src/integer/mod.rs b/tfhe/src/integer/mod.rs index eeb64f0a9..618be93c3 100755 --- a/tfhe/src/integer/mod.rs +++ b/tfhe/src/integer/mod.rs @@ -50,18 +50,22 @@ extern crate core; #[cfg(test)] #[macro_use] mod tests; +mod encryption; pub mod ciphertext; pub mod client_key; #[cfg(any(test, feature = "internal-keycache"))] pub mod keycache; pub mod parameters; +pub mod public_key; pub mod server_key; pub mod u256; pub mod wopbs; pub use ciphertext::{CrtCiphertext, IntegerCiphertext, RadixCiphertext}; pub use client_key::{ClientKey, CrtClientKey, RadixClientKey}; +pub use encryption::ClearText; +pub use public_key::{CompressedPublicKey, PublicKey}; pub use server_key::{CheckError, ServerKey}; pub use u256::U256; diff --git a/tfhe/src/integer/public_key/compressed.rs b/tfhe/src/integer/public_key/compressed.rs new file mode 100644 index 000000000..021c61c9f --- /dev/null +++ b/tfhe/src/integer/public_key/compressed.rs @@ -0,0 +1,81 @@ +use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext}; +use crate::integer::client_key::ClientKey; +use crate::integer::encryption::{encrypt_crt, encrypt_words_radix, ClearText}; +use crate::shortint::parameters::MessageModulus; +use crate::shortint::CompressedPublicKey as ShortintCompressedPublicKey; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct CompressedPublicKey { + key: ShortintCompressedPublicKey, +} + +impl CompressedPublicKey { + pub fn new(client_key: &ClientKey) -> Self { + Self { + key: ShortintCompressedPublicKey::new(&client_key.key), + } + } + pub fn parameters(&self) -> crate::shortint::Parameters { + self.key.parameters + } + + pub fn encrypt_radix(&self, message: T, num_blocks: usize) -> RadixCiphertext { + self.encrypt_words_radix( + message.as_words(), + num_blocks, + crate::shortint::CompressedPublicKey::encrypt, + ) + } + + pub fn encrypt_radix_without_padding( + &self, + message: u64, + num_blocks: usize, + ) -> RadixCiphertext { + self.encrypt_words_radix( + message.as_words(), + num_blocks, + crate::shortint::CompressedPublicKey::encrypt_without_padding, + ) + } + + pub fn encrypt_words_radix( + &self, + message_words: &[u64], + num_blocks: usize, + encrypt_block: F, + ) -> RadixCiphertextType + where + F: Fn(&crate::shortint::CompressedPublicKey, u64) -> Block, + RadixCiphertextType: From>, + { + encrypt_words_radix(&self.key, message_words, num_blocks, encrypt_block) + } + + pub fn encrypt_crt(&self, message: u64, base_vec: Vec) -> CrtCiphertext { + self.encrypt_crt_impl( + message, + base_vec, + crate::shortint::CompressedPublicKey::encrypt_with_message_modulus, + ) + } + + pub fn encrypt_native_crt(&self, message: u64, base_vec: Vec) -> CrtCiphertext { + self.encrypt_crt_impl(message, base_vec, |cks, msg, moduli| { + cks.encrypt_native_crt(msg, moduli.0 as u8) + }) + } + + fn encrypt_crt_impl( + &self, + message: u64, + base_vec: Vec, + encrypt_block: F, + ) -> CrtCiphertextType + where + F: Fn(&crate::shortint::CompressedPublicKey, u64, MessageModulus) -> Block, + CrtCiphertextType: From<(Vec, Vec)>, + { + encrypt_crt(&self.key, message, base_vec, encrypt_block) + } +} diff --git a/tfhe/src/integer/public_key/mod.rs b/tfhe/src/integer/public_key/mod.rs new file mode 100644 index 000000000..53110536c --- /dev/null +++ b/tfhe/src/integer/public_key/mod.rs @@ -0,0 +1,10 @@ +//! Module with the definition of the encryption PublicKey. + +pub mod compressed; +pub mod standard; + +pub use compressed::CompressedPublicKey; +pub use standard::PublicKey; + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/integer/public_key/standard.rs b/tfhe/src/integer/public_key/standard.rs new file mode 100644 index 000000000..637ff7d47 --- /dev/null +++ b/tfhe/src/integer/public_key/standard.rs @@ -0,0 +1,81 @@ +use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext}; +use crate::integer::client_key::ClientKey; +use crate::integer::encryption::{encrypt_crt, encrypt_words_radix, ClearText}; +use crate::shortint::parameters::MessageModulus; +use crate::shortint::PublicKey as ShortintPublicKey; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PublicKey { + key: ShortintPublicKey, +} + +impl PublicKey { + pub fn new(client_key: &ClientKey) -> Self { + Self { + key: ShortintPublicKey::new(&client_key.key), + } + } + pub fn parameters(&self) -> crate::shortint::Parameters { + self.key.parameters + } + + pub fn encrypt_radix(&self, message: T, num_blocks: usize) -> RadixCiphertext { + self.encrypt_words_radix( + message.as_words(), + num_blocks, + crate::shortint::PublicKey::encrypt, + ) + } + + pub fn encrypt_radix_without_padding( + &self, + message: u64, + num_blocks: usize, + ) -> RadixCiphertext { + self.encrypt_words_radix( + message.as_words(), + num_blocks, + crate::shortint::PublicKey::encrypt_without_padding, + ) + } + + pub fn encrypt_words_radix( + &self, + message_words: &[u64], + num_blocks: usize, + encrypt_block: F, + ) -> RadixCiphertextType + where + F: Fn(&crate::shortint::PublicKey, u64) -> Block, + RadixCiphertextType: From>, + { + encrypt_words_radix(&self.key, message_words, num_blocks, encrypt_block) + } + + pub fn encrypt_crt(&self, message: u64, base_vec: Vec) -> CrtCiphertext { + self.encrypt_crt_impl( + message, + base_vec, + crate::shortint::PublicKey::encrypt_with_message_modulus, + ) + } + + pub fn encrypt_native_crt(&self, message: u64, base_vec: Vec) -> CrtCiphertext { + self.encrypt_crt_impl(message, base_vec, |cks, msg, moduli| { + cks.encrypt_native_crt(msg, moduli.0 as u8) + }) + } + + fn encrypt_crt_impl( + &self, + message: u64, + base_vec: Vec, + encrypt_block: F, + ) -> CrtCiphertextType + where + F: Fn(&crate::shortint::PublicKey, u64, MessageModulus) -> Block, + CrtCiphertextType: From<(Vec, Vec)>, + { + encrypt_crt(&self.key, message, base_vec, encrypt_block) + } +} diff --git a/tfhe/src/integer/public_key/tests.rs b/tfhe/src/integer/public_key/tests.rs new file mode 100644 index 000000000..d83e9f467 --- /dev/null +++ b/tfhe/src/integer/public_key/tests.rs @@ -0,0 +1,64 @@ +use rand::Rng; + +use crate::integer::{CompressedPublicKey, PublicKey}; +use crate::shortint::parameters::*; +use crate::shortint::Parameters; + +use crate::integer::keycache::KEY_CACHE; + +create_parametrized_test!(radix_encrypt_decrypt_128_bits { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_2_CARRY_2 /* PARAM_MESSAGE_3_CARRY_3, Skipped as the key requires 32GB + * PARAM_MESSAGE_4_CARRY_4, Skipped as the key requires 550GB */ +}); +create_parametrized_test!(radix_encrypt_decrypt_compressed_128_bits { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_2_CARRY_2 /* PARAM_MESSAGE_3_CARRY_3, Skipped as its slow + * PARAM_MESSAGE_4_CARRY_4, Skipped as its slow */ +}); + +/// Test that the public key can encrypt a 128 bit number +/// in radix decomposition, and that the client key can decrypt it +fn radix_encrypt_decrypt_128_bits(param: Parameters) { + let (cks, _) = KEY_CACHE.get_from_params(param); + let public_key = PublicKey::new(&cks); + + // RNG + let mut rng = rand::thread_rng(); + let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + + let clear = rng.gen::(); + + //encryption + let ct = public_key.encrypt_radix(clear, num_block); + + // decryption + let mut dec = 0u128; + cks.decrypt_radix_into(&ct, &mut dec); + + // assert + assert_eq!(clear, dec); +} + +fn radix_encrypt_decrypt_compressed_128_bits(param: Parameters) { + let (cks, _) = KEY_CACHE.get_from_params(param); + println!("Compressed public key gen start"); + let public_key = CompressedPublicKey::new(&cks); + println!("done"); + + // RNG + let mut rng = rand::thread_rng(); + let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + + let clear = rng.gen::(); + + //encryption + let ct = public_key.encrypt_radix(clear, num_block); + + // decryption + let mut dec = 0u128; + cks.decrypt_radix_into(&ct, &mut dec); + + // assert + assert_eq!(clear, dec); +}