mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(integer): add PublicKey
This commit is contained in:
@@ -11,6 +11,7 @@ use crate::integer::ciphertext::{
|
||||
CompressedCrtCiphertext, CompressedRadixCiphertext, CrtCiphertext, RadixCiphertext,
|
||||
};
|
||||
use crate::integer::client_key::utils::i_crt;
|
||||
use crate::integer::encryption::{encrypt_crt, encrypt_words_radix, ClearText};
|
||||
use crate::shortint::parameters::MessageModulus;
|
||||
use crate::shortint::{
|
||||
Ciphertext as ShortintCiphertext, ClientKey as ShortintClientKey,
|
||||
@@ -19,7 +20,6 @@ use crate::shortint::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use utils::radix_decomposition;
|
||||
|
||||
use crate::integer::U256;
|
||||
pub use crt::CrtClientKey;
|
||||
pub use radix::RadixClientKey;
|
||||
|
||||
@@ -53,46 +53,6 @@ impl AsRef<ClientKey> for ClientKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ClearText {
|
||||
fn as_words(&self) -> &[u64];
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64];
|
||||
}
|
||||
|
||||
impl ClearText for u64 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
std::slice::from_ref(self)
|
||||
}
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64] {
|
||||
std::slice::from_mut(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ClearText for u128 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
let u128_slc = std::slice::from_ref(self);
|
||||
unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 2) }
|
||||
}
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64] {
|
||||
let u128_slc = std::slice::from_mut(self);
|
||||
unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 2) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ClearText for U256 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
let u128_slc = self.0.as_slice();
|
||||
unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 4) }
|
||||
}
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64] {
|
||||
let u128_slc = self.0.as_mut_slice();
|
||||
unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 4) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientKey {
|
||||
/// Creates a Client Key.
|
||||
///
|
||||
@@ -217,40 +177,7 @@ impl ClientKey {
|
||||
F: Fn(&crate::shortint::ClientKey, u64) -> Block,
|
||||
RadixCiphertextType: From<Vec<Block>>,
|
||||
{
|
||||
let mask = (self.key.parameters.message_modulus.0 - 1) as u128;
|
||||
let block_modulus = self.key.parameters.message_modulus.0 as u128;
|
||||
|
||||
let mut blocks = Vec::with_capacity(num_blocks);
|
||||
let mut message_block_iter = message_words.iter().copied();
|
||||
|
||||
let mut source = 0u128; // stores the bits of the word to be encrypted in one of the iteration
|
||||
let mut valid_until_power = 1; // 2^0 = 1, start with nothing valid
|
||||
let mut current_power = 1; // where the next bits to encrypt starts
|
||||
for _ in 0..num_blocks {
|
||||
// Are we going to encrypt bits that are not valid ?
|
||||
// If so, discard already encrypted bits and fetch bits form the input words
|
||||
if (current_power * block_modulus) >= valid_until_power {
|
||||
source /= current_power;
|
||||
valid_until_power /= current_power;
|
||||
|
||||
source += message_block_iter
|
||||
.next()
|
||||
.map(u128::from)
|
||||
.unwrap_or_default()
|
||||
* valid_until_power;
|
||||
|
||||
current_power = 1;
|
||||
valid_until_power <<= 64;
|
||||
}
|
||||
|
||||
let block_value = (source & (mask * current_power)) / current_power;
|
||||
let ct = encrypt_block(&self.key, block_value as u64);
|
||||
blocks.push(ct);
|
||||
|
||||
current_power *= block_modulus;
|
||||
}
|
||||
|
||||
RadixCiphertextType::from(blocks)
|
||||
encrypt_words_radix(&self.key, message_words, num_blocks, encrypt_block)
|
||||
}
|
||||
|
||||
/// Encrypts one block.
|
||||
@@ -544,7 +471,7 @@ impl ClientKey {
|
||||
result % whole_modulus
|
||||
}
|
||||
|
||||
pub fn encrypt_crt_impl<Block, CrtCiphertextType, F>(
|
||||
fn encrypt_crt_impl<Block, CrtCiphertextType, F>(
|
||||
&self,
|
||||
message: u64,
|
||||
base_vec: Vec<u64>,
|
||||
@@ -554,16 +481,6 @@ impl ClientKey {
|
||||
F: Fn(&crate::shortint::ClientKey, u64, MessageModulus) -> Block,
|
||||
CrtCiphertextType: From<(Vec<Block>, Vec<u64>)>,
|
||||
{
|
||||
let mut ctxt_vect = Vec::with_capacity(base_vec.len());
|
||||
|
||||
// Put each decomposition into a new ciphertext
|
||||
for modulus in base_vec.iter().copied() {
|
||||
// encryption
|
||||
let ct = encrypt_block(&self.key, message, MessageModulus(modulus as usize));
|
||||
|
||||
ctxt_vect.push(ct);
|
||||
}
|
||||
|
||||
CrtCiphertextType::from((ctxt_vect, base_vec))
|
||||
encrypt_crt(&self.key, message, base_vec, encrypt_block)
|
||||
}
|
||||
}
|
||||
|
||||
134
tfhe/src/integer/encryption.rs
Normal file
134
tfhe/src/integer/encryption.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use super::U256;
|
||||
use crate::shortint::parameters::MessageModulus;
|
||||
|
||||
pub trait ClearText {
|
||||
fn as_words(&self) -> &[u64];
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64];
|
||||
}
|
||||
|
||||
impl ClearText for u64 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
std::slice::from_ref(self)
|
||||
}
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64] {
|
||||
std::slice::from_mut(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ClearText for u128 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
let u128_slc = std::slice::from_ref(self);
|
||||
unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 2) }
|
||||
}
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64] {
|
||||
let u128_slc = std::slice::from_mut(self);
|
||||
unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 2) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ClearText for U256 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
let u128_slc = self.0.as_slice();
|
||||
unsafe { std::slice::from_raw_parts(u128_slc.as_ptr() as *const u64, 4) }
|
||||
}
|
||||
|
||||
fn as_words_mut(&mut self) -> &mut [u64] {
|
||||
let u128_slc = self.0.as_mut_slice();
|
||||
unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 4) }
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait BlockEncryptionKey {
|
||||
fn parameters(&self) -> &crate::shortint::Parameters;
|
||||
}
|
||||
|
||||
impl BlockEncryptionKey for crate::shortint::ClientKey {
|
||||
fn parameters(&self) -> &crate::shortint::Parameters {
|
||||
&self.parameters
|
||||
}
|
||||
}
|
||||
|
||||
impl BlockEncryptionKey for crate::shortint::PublicKey {
|
||||
fn parameters(&self) -> &crate::shortint::Parameters {
|
||||
&self.parameters
|
||||
}
|
||||
}
|
||||
|
||||
impl BlockEncryptionKey for crate::shortint::CompressedPublicKey {
|
||||
fn parameters(&self) -> &crate::shortint::Parameters {
|
||||
&self.parameters
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_words_radix<BlockKey, Block, RadixCiphertextType, F>(
|
||||
encrypting_key: &BlockKey,
|
||||
message_words: &[u64],
|
||||
num_blocks: usize,
|
||||
encrypt_block: F,
|
||||
) -> RadixCiphertextType
|
||||
where
|
||||
BlockKey: BlockEncryptionKey,
|
||||
F: Fn(&BlockKey, u64) -> Block,
|
||||
RadixCiphertextType: From<Vec<Block>>,
|
||||
{
|
||||
let mask = (encrypting_key.parameters().message_modulus.0 - 1) as u128;
|
||||
let block_modulus = encrypting_key.parameters().message_modulus.0 as u128;
|
||||
|
||||
let mut blocks = Vec::with_capacity(num_blocks);
|
||||
let mut message_block_iter = message_words.iter().copied();
|
||||
|
||||
let mut source = 0u128; // stores the bits of the word to be encrypted in one of the iteration
|
||||
let mut valid_until_power = 1; // 2^0 = 1, start with nothing valid
|
||||
let mut current_power = 1; // where the next bits to encrypt starts
|
||||
for _ in 0..num_blocks {
|
||||
// Are we going to encrypt bits that are not valid ?
|
||||
// If so, discard already encrypted bits and fetch bits form the input words
|
||||
if (current_power * block_modulus) >= valid_until_power {
|
||||
source /= current_power;
|
||||
valid_until_power /= current_power;
|
||||
|
||||
source += message_block_iter
|
||||
.next()
|
||||
.map(u128::from)
|
||||
.unwrap_or_default()
|
||||
* valid_until_power;
|
||||
|
||||
current_power = 1;
|
||||
valid_until_power <<= 64;
|
||||
}
|
||||
|
||||
let block_value = (source & (mask * current_power)) / current_power;
|
||||
let ct = encrypt_block(encrypting_key, block_value as u64);
|
||||
blocks.push(ct);
|
||||
|
||||
current_power *= block_modulus;
|
||||
}
|
||||
|
||||
RadixCiphertextType::from(blocks)
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_crt<BlockKey, Block, CrtCiphertextType, F>(
|
||||
encrypting_key: &BlockKey,
|
||||
message: u64,
|
||||
base_vec: Vec<u64>,
|
||||
encrypt_block: F,
|
||||
) -> CrtCiphertextType
|
||||
where
|
||||
F: Fn(&BlockKey, u64, MessageModulus) -> Block,
|
||||
CrtCiphertextType: From<(Vec<Block>, Vec<u64>)>,
|
||||
{
|
||||
let mut ctxt_vect = Vec::with_capacity(base_vec.len());
|
||||
|
||||
// Put each decomposition into a new ciphertext
|
||||
for modulus in base_vec.iter().copied() {
|
||||
// encryption
|
||||
let ct = encrypt_block(encrypting_key, message, MessageModulus(modulus as usize));
|
||||
|
||||
ctxt_vect.push(ct);
|
||||
}
|
||||
|
||||
CrtCiphertextType::from((ctxt_vect, base_vec))
|
||||
}
|
||||
@@ -50,18 +50,22 @@ extern crate core;
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
mod tests;
|
||||
mod encryption;
|
||||
|
||||
pub mod ciphertext;
|
||||
pub mod client_key;
|
||||
#[cfg(any(test, feature = "internal-keycache"))]
|
||||
pub mod keycache;
|
||||
pub mod parameters;
|
||||
pub mod public_key;
|
||||
pub mod server_key;
|
||||
pub mod u256;
|
||||
pub mod wopbs;
|
||||
|
||||
pub use ciphertext::{CrtCiphertext, IntegerCiphertext, RadixCiphertext};
|
||||
pub use client_key::{ClientKey, CrtClientKey, RadixClientKey};
|
||||
pub use encryption::ClearText;
|
||||
pub use public_key::{CompressedPublicKey, PublicKey};
|
||||
pub use server_key::{CheckError, ServerKey};
|
||||
pub use u256::U256;
|
||||
|
||||
|
||||
81
tfhe/src/integer/public_key/compressed.rs
Normal file
81
tfhe/src/integer/public_key/compressed.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext};
|
||||
use crate::integer::client_key::ClientKey;
|
||||
use crate::integer::encryption::{encrypt_crt, encrypt_words_radix, ClearText};
|
||||
use crate::shortint::parameters::MessageModulus;
|
||||
use crate::shortint::CompressedPublicKey as ShortintCompressedPublicKey;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CompressedPublicKey {
|
||||
key: ShortintCompressedPublicKey,
|
||||
}
|
||||
|
||||
impl CompressedPublicKey {
|
||||
pub fn new(client_key: &ClientKey) -> Self {
|
||||
Self {
|
||||
key: ShortintCompressedPublicKey::new(&client_key.key),
|
||||
}
|
||||
}
|
||||
pub fn parameters(&self) -> crate::shortint::Parameters {
|
||||
self.key.parameters
|
||||
}
|
||||
|
||||
pub fn encrypt_radix<T: ClearText>(&self, message: T, num_blocks: usize) -> RadixCiphertext {
|
||||
self.encrypt_words_radix(
|
||||
message.as_words(),
|
||||
num_blocks,
|
||||
crate::shortint::CompressedPublicKey::encrypt,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encrypt_radix_without_padding(
|
||||
&self,
|
||||
message: u64,
|
||||
num_blocks: usize,
|
||||
) -> RadixCiphertext {
|
||||
self.encrypt_words_radix(
|
||||
message.as_words(),
|
||||
num_blocks,
|
||||
crate::shortint::CompressedPublicKey::encrypt_without_padding,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encrypt_words_radix<Block, RadixCiphertextType, F>(
|
||||
&self,
|
||||
message_words: &[u64],
|
||||
num_blocks: usize,
|
||||
encrypt_block: F,
|
||||
) -> RadixCiphertextType
|
||||
where
|
||||
F: Fn(&crate::shortint::CompressedPublicKey, u64) -> Block,
|
||||
RadixCiphertextType: From<Vec<Block>>,
|
||||
{
|
||||
encrypt_words_radix(&self.key, message_words, num_blocks, encrypt_block)
|
||||
}
|
||||
|
||||
pub fn encrypt_crt(&self, message: u64, base_vec: Vec<u64>) -> CrtCiphertext {
|
||||
self.encrypt_crt_impl(
|
||||
message,
|
||||
base_vec,
|
||||
crate::shortint::CompressedPublicKey::encrypt_with_message_modulus,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encrypt_native_crt(&self, message: u64, base_vec: Vec<u64>) -> CrtCiphertext {
|
||||
self.encrypt_crt_impl(message, base_vec, |cks, msg, moduli| {
|
||||
cks.encrypt_native_crt(msg, moduli.0 as u8)
|
||||
})
|
||||
}
|
||||
|
||||
fn encrypt_crt_impl<Block, CrtCiphertextType, F>(
|
||||
&self,
|
||||
message: u64,
|
||||
base_vec: Vec<u64>,
|
||||
encrypt_block: F,
|
||||
) -> CrtCiphertextType
|
||||
where
|
||||
F: Fn(&crate::shortint::CompressedPublicKey, u64, MessageModulus) -> Block,
|
||||
CrtCiphertextType: From<(Vec<Block>, Vec<u64>)>,
|
||||
{
|
||||
encrypt_crt(&self.key, message, base_vec, encrypt_block)
|
||||
}
|
||||
}
|
||||
10
tfhe/src/integer/public_key/mod.rs
Normal file
10
tfhe/src/integer/public_key/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
//! Module with the definition of the encryption PublicKey.
|
||||
|
||||
pub mod compressed;
|
||||
pub mod standard;
|
||||
|
||||
pub use compressed::CompressedPublicKey;
|
||||
pub use standard::PublicKey;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
81
tfhe/src/integer/public_key/standard.rs
Normal file
81
tfhe/src/integer/public_key/standard.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::integer::ciphertext::{CrtCiphertext, RadixCiphertext};
|
||||
use crate::integer::client_key::ClientKey;
|
||||
use crate::integer::encryption::{encrypt_crt, encrypt_words_radix, ClearText};
|
||||
use crate::shortint::parameters::MessageModulus;
|
||||
use crate::shortint::PublicKey as ShortintPublicKey;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct PublicKey {
|
||||
key: ShortintPublicKey,
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
pub fn new(client_key: &ClientKey) -> Self {
|
||||
Self {
|
||||
key: ShortintPublicKey::new(&client_key.key),
|
||||
}
|
||||
}
|
||||
pub fn parameters(&self) -> crate::shortint::Parameters {
|
||||
self.key.parameters
|
||||
}
|
||||
|
||||
pub fn encrypt_radix<T: ClearText>(&self, message: T, num_blocks: usize) -> RadixCiphertext {
|
||||
self.encrypt_words_radix(
|
||||
message.as_words(),
|
||||
num_blocks,
|
||||
crate::shortint::PublicKey::encrypt,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encrypt_radix_without_padding(
|
||||
&self,
|
||||
message: u64,
|
||||
num_blocks: usize,
|
||||
) -> RadixCiphertext {
|
||||
self.encrypt_words_radix(
|
||||
message.as_words(),
|
||||
num_blocks,
|
||||
crate::shortint::PublicKey::encrypt_without_padding,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encrypt_words_radix<Block, RadixCiphertextType, F>(
|
||||
&self,
|
||||
message_words: &[u64],
|
||||
num_blocks: usize,
|
||||
encrypt_block: F,
|
||||
) -> RadixCiphertextType
|
||||
where
|
||||
F: Fn(&crate::shortint::PublicKey, u64) -> Block,
|
||||
RadixCiphertextType: From<Vec<Block>>,
|
||||
{
|
||||
encrypt_words_radix(&self.key, message_words, num_blocks, encrypt_block)
|
||||
}
|
||||
|
||||
pub fn encrypt_crt(&self, message: u64, base_vec: Vec<u64>) -> CrtCiphertext {
|
||||
self.encrypt_crt_impl(
|
||||
message,
|
||||
base_vec,
|
||||
crate::shortint::PublicKey::encrypt_with_message_modulus,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encrypt_native_crt(&self, message: u64, base_vec: Vec<u64>) -> CrtCiphertext {
|
||||
self.encrypt_crt_impl(message, base_vec, |cks, msg, moduli| {
|
||||
cks.encrypt_native_crt(msg, moduli.0 as u8)
|
||||
})
|
||||
}
|
||||
|
||||
fn encrypt_crt_impl<Block, CrtCiphertextType, F>(
|
||||
&self,
|
||||
message: u64,
|
||||
base_vec: Vec<u64>,
|
||||
encrypt_block: F,
|
||||
) -> CrtCiphertextType
|
||||
where
|
||||
F: Fn(&crate::shortint::PublicKey, u64, MessageModulus) -> Block,
|
||||
CrtCiphertextType: From<(Vec<Block>, Vec<u64>)>,
|
||||
{
|
||||
encrypt_crt(&self.key, message, base_vec, encrypt_block)
|
||||
}
|
||||
}
|
||||
64
tfhe/src/integer/public_key/tests.rs
Normal file
64
tfhe/src/integer/public_key/tests.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use rand::Rng;
|
||||
|
||||
use crate::integer::{CompressedPublicKey, PublicKey};
|
||||
use crate::shortint::parameters::*;
|
||||
use crate::shortint::Parameters;
|
||||
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
|
||||
create_parametrized_test!(radix_encrypt_decrypt_128_bits {
|
||||
PARAM_MESSAGE_1_CARRY_1,
|
||||
PARAM_MESSAGE_2_CARRY_2 /* PARAM_MESSAGE_3_CARRY_3, Skipped as the key requires 32GB
|
||||
* PARAM_MESSAGE_4_CARRY_4, Skipped as the key requires 550GB */
|
||||
});
|
||||
create_parametrized_test!(radix_encrypt_decrypt_compressed_128_bits {
|
||||
PARAM_MESSAGE_1_CARRY_1,
|
||||
PARAM_MESSAGE_2_CARRY_2 /* PARAM_MESSAGE_3_CARRY_3, Skipped as its slow
|
||||
* PARAM_MESSAGE_4_CARRY_4, Skipped as its slow */
|
||||
});
|
||||
|
||||
/// Test that the public key can encrypt a 128 bit number
|
||||
/// in radix decomposition, and that the client key can decrypt it
|
||||
fn radix_encrypt_decrypt_128_bits(param: Parameters) {
|
||||
let (cks, _) = KEY_CACHE.get_from_params(param);
|
||||
let public_key = PublicKey::new(&cks);
|
||||
|
||||
// RNG
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
let clear = rng.gen::<u128>();
|
||||
|
||||
//encryption
|
||||
let ct = public_key.encrypt_radix(clear, num_block);
|
||||
|
||||
// decryption
|
||||
let mut dec = 0u128;
|
||||
cks.decrypt_radix_into(&ct, &mut dec);
|
||||
|
||||
// assert
|
||||
assert_eq!(clear, dec);
|
||||
}
|
||||
|
||||
fn radix_encrypt_decrypt_compressed_128_bits(param: Parameters) {
|
||||
let (cks, _) = KEY_CACHE.get_from_params(param);
|
||||
println!("Compressed public key gen start");
|
||||
let public_key = CompressedPublicKey::new(&cks);
|
||||
println!("done");
|
||||
|
||||
// RNG
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
let clear = rng.gen::<u128>();
|
||||
|
||||
//encryption
|
||||
let ct = public_key.encrypt_radix(clear, num_block);
|
||||
|
||||
// decryption
|
||||
let mut dec = 0u128;
|
||||
cks.decrypt_radix_into(&ct, &mut dec);
|
||||
|
||||
// assert
|
||||
assert_eq!(clear, dec);
|
||||
}
|
||||
Reference in New Issue
Block a user