Compare commits

...

11 Commits

Author SHA1 Message Date
Loris
8ab7ba1ba8 new param 2023-03-28 13:10:05 +09:00
Loris
d4864e766a mul natif crt 2023-03-26 17:42:15 +09:00
J-B Orfila
ddde0637c6 WIP 2023-03-26 07:32:23 +02:00
J-B Orfila
fcf92f9df8 WIP 2023-03-26 05:59:05 +02:00
J-B Orfila
05860574a6 WIP 2023-03-26 05:38:33 +02:00
J-B Orfila
8332bd15ae WIP 2023-03-26 05:18:55 +02:00
J-B Orfila
dad57adbf3 WIP 2023-03-25 16:35:46 +01:00
J-B Orfila
cc375eab36 WIP 2023-03-25 08:36:52 +01:00
Loris
954bc4b75a test hybrid 2023-03-16 18:13:31 +01:00
J-B Orfila
5db89e11ad WIP 2023-03-16 11:18:02 +01:00
J-B Orfila
12610ad803 integers: WIP joc 2023-03-10 16:54:45 +01:00
24 changed files with 2182 additions and 40 deletions

View File

@@ -6,10 +6,12 @@ use tfhe::integer::client_key::radix_decomposition;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::parameters::*;
use tfhe::integer::wopbs::WopbsKey;
use tfhe::integer::{gen_keys, RadixCiphertext, ServerKey};
use tfhe::shortint::keycache::KEY_CACHE_WOPBS;
use tfhe::integer::{gen_keys, IntegerCiphertext, RadixCiphertext, ServerKey};
use tfhe::integer::ciphertext::crt_ciphertext_from_ciphertext;
use tfhe::integer::parameters::parameters_benches_joc::*;
use tfhe::shortint::keycache::{KEY_CACHE_WOPBS, NamedParam};
use tfhe::shortint::parameters::parameters_wopbs_message_carry::get_parameters_from_message_and_carry_wopbs;
use tfhe::shortint::parameters::{get_parameters_from_message_and_carry, DEFAULT_PARAMETERS};
use tfhe::shortint::parameters::{get_parameters_from_message_and_carry, DEFAULT_PARAMETERS, MessageModulus, CarryModulus};
criterion_group!(
to_be_reworked,
@@ -274,13 +276,25 @@ criterion_group!(
criterion_group!(misc, full_propagate,);
criterion_group!(joc,
//joc_radix,
//joc_radix_wopbs,
//joc_crt,
//joc_hybrid_32_bits,
//joc_crt_wopbs,
//joc_native_crt_wopbs,
joc_native_crt_mul_wopbs,
joc_native_crt_add,
);
criterion_main!(
smart_arithmetic_operation,
smart_scalar_arithmetic_operation,
unchecked_arithmetic_operation,
unchecked_scalar_arithmetic_operation,
misc,
to_be_reworked,
// smart_arithmetic_operation,
// smart_scalar_arithmetic_operation,
// unchecked_arithmetic_operation,
// unchecked_scalar_arithmetic_operation,
// misc,
// to_be_reworked,
joc,
);
fn smart_block_mul(c: &mut Criterion) {
@@ -910,3 +924,438 @@ fn concrete_integer_unchecked_clean_carry_crt_32_bits(c: &mut Criterion) {
})
});
}
fn joc_radix(c: &mut Criterion) {
let param_vec = vec![
ID_1_RADIX_16_BITS_16_BLOCKS,
ID_2_RADIX_16_BITS_8_BLOCKS,
ID_4_RADIX_32_BITS_32_BLOCKS,
ID_5_RADIX_32_BITS_16_BLOCKS,
ID_6_RADIX_32_BITS_8_BLOCKS
];
let nb_blocks_vec = [
16,
8,
32,
16,
8
];
for (param, nb_blocks) in param_vec.iter().zip(nb_blocks_vec.iter()) {
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let modulus = param.message_modulus.0.pow(*nb_blocks as u32) as u64;
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
println!("Chosen Parameter Set: {param:?}");
let clear_0 = 29 % modulus;
// encryption of an integer
let mut ct_zero = cks.encrypt_radix(clear_0, *nb_blocks);
let mut ct_one = cks.encrypt_radix(clear_0, *nb_blocks);
let id = format!("{}_add", group_name.clone());
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
sks.unchecked_add(&mut ct_zero, &ct_one);
})
});
let id = format!("{}_mul", group_name.clone());
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
sks.unchecked_mul(&mut ct_zero, &ct_one);
})
});
let id = format!("{}_carry_propagate", group_name);
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
sks.full_propagate(&mut ct_zero);
})
});
}
}
fn joc_radix_wopbs(c: &mut Criterion) {
let param_vec = vec![
ID_7_RADIX_16_BITS_16_BLOCKS_WOPBS,
ID_8_RADIX_16_BITS_8_BLOCKS_WOPBS
];
let nb_blocks_vec = vec![
16,
8
];
for (param, nb_blocks) in param_vec.iter().zip(nb_blocks_vec.iter()) {
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
let mut rng = rand::thread_rng();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
let mut msg_space: u64 = param.message_modulus.0 as u64;
for _ in 1..*nb_blocks {
msg_space *= param.message_modulus.0 as u64;
}
let clear1 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_radix(clear1, *nb_blocks);
let lut = wopbs_key.generate_lut_radix(&ct1, |x| x);
let id = format!("{}_wopbs", group_name.clone());
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
let ct_res = wopbs_key.wopbs(&ct1, &lut);
})
});
}
}
fn joc_crt(c: &mut Criterion) {
let param_vec = vec![ID_3_CRT_16_BITS_5_BLOCKS];
let basis_16bits = vec![7,8,9,11,13];
let basis_vec = [basis_16bits];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let modulus = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
let mut rng = rand::thread_rng();
let mut clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
// encryption of an integer
let mut ct_zero = cks.encrypt_crt(clear_0, basis.to_vec());
let mut ct_one = cks.encrypt_crt(clear_1, basis.to_vec());
let id = format!("{}_add", group_name.clone());
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
sks.unchecked_crt_add(&mut ct_zero, &ct_one);
})
});
let id = format!("{}_mul", group_name.clone());
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
sks.unchecked_crt_mul(&mut ct_zero, &ct_one);
})
});
let id = format!("{}_carry_propagate", group_name);
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
sks.full_extract_message_assign(&mut ct_zero);
})
});
}
}
fn joc_crt_wopbs(c: &mut Criterion) {
let param_vec = vec![
ID_9_CRT_16_BITS_5_BLOCKS_WOPBS,
];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
let basis_vec = [basis_16bits];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
let clear1 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_crt(clear1, basis.to_vec());
let lut = wopbs_key.generate_lut_crt(&ct1, |x| x);
let id = format!("{}_crt_wopbs", group_name);
// add the two ciphertexts
group.bench_function(id, |b| {
b.iter(|| {
let ct_res = wopbs_key.wopbs(&ct1, &lut);
})
});
}
}
fn joc_native_crt_wopbs(c: &mut Criterion) {
let param_vec = vec![
ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS,
//ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS
];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
//let basis_32bits = vec![43,47,37,49,29,41];
let basis_vec = [
basis_16bits,
//basis_32bits,
];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
let clear1 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_native_crt(clear1, basis.to_vec());
let lut = wopbs_key.generate_lut_native_crt(&ct1, |x| x);
let id = format!("{}_native_crt_wopbs", group_name);
group.bench_function(id, |b| {
b.iter(|| {
let ct_res = wopbs_key.wopbs_native_crt(&ct1, &lut);
})
});
}
}
fn joc_native_crt_add(c: &mut Criterion) {
let param_vec = vec![
//ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS,
//ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS,
ID_11_BIS_NATIF_CRT_32_BITS_8_BLOCKS_WOPBS,
];
// Define CRT basis, and global modulus
//let basis_16bits = vec![7,8,9,11,13];
//let basis_32bits = vec![43,47,37,49,29,41];
let basis_32bits_bis = vec![3, 11, 13, 19, 23, 29, 31, 32];
let basis_vec = [
//basis_16bits,
//basis_32bits,
basis_32bits_bis,
];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let clear1 = rng.gen::<u64>() % msg_space;
let clear0 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_native_crt(clear1, basis.to_vec());
let ct0 = cks.encrypt_native_crt(clear0, basis.to_vec());
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
let id = format!("{}_native_crt_add", group_name);
group.bench_function(id, |b| {
b.iter(|| {
let ct_res = sks.unchecked_crt_add(&ct1, &ct0);
})
});
}
}
fn joc_native_crt_mul_wopbs(c: &mut Criterion) {
let param_vec = vec![
//ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS,
//ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS,
ID_11_BIS_NATIF_CRT_32_BITS_8_BLOCKS_WOPBS
];
//let basis_16bits = vec![7,8,9,11,13];
//let basis_32bits = vec![43,47,37,49,29,41];
let basis_32bits_bis = vec![3, 11, 13, 19, 23, 29, 31, 32];
let basis_vec = [
//basis_16bits,
//basis_32bits,
basis_32bits_bis,
];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
let clear1 = rng.gen::<u64>() % msg_space;
let clear2 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_native_crt(clear1, basis.to_vec());
let ct2 = cks.encrypt_native_crt(clear2, basis.to_vec());
let mut ct_res = ct1.clone();
let mut i = 0;
for ((ct_left, ct_right), res) in ct1.blocks.iter().zip(ct2.blocks.iter()).zip
(ct_res.blocks.iter_mut()) {
let crt_left = crt_ciphertext_from_ciphertext(&ct_left);
let crt_right = crt_ciphertext_from_ciphertext(&ct_right);
let mut crt_res = crt_ciphertext_from_ciphertext(&res);
let lut = wopbs_key.generate_lut_bivariate_native_crt(&crt_left, |x, y|
x * y);
let id = format!("{}_native_crt_wopbs_mul_block_{}", group_name, i);
group.bench_function(id, |b| {
b.iter(|| {
crt_res = wopbs_key.bivariate_wopbs_native_crt(&crt_left, &crt_right, &lut);
})
});
i = i+ 1;
}
}
}
fn joc_hybrid_32_bits(c: &mut Criterion) {
let param = ID_12_HYBRID_CRT_32_bits;
// basis = 2^5 * 3^5* 5^4 * 7^4
let basis_32bits = vec![
32,
243,
625,
2401
];
let modulus_vec = [
8,
3,
5,
7,
];
let nb_blocks_vec = [
4,
5,
4,
4,
];
let message_carry_mod_vec = [
(MessageModulus(8), CarryModulus(8)),
(MessageModulus(8), CarryModulus(8)),
(MessageModulus(8), CarryModulus(8)),
(MessageModulus(8), CarryModulus(8)),
];
let mut i= 0;
for (block_modulus, nb_blocks) in modulus_vec.iter().zip(nb_blocks_vec.iter
()) {
let (mut cks, mut sks) = KEY_CACHE.get_from_params(param);
cks.key.parameters.message_modulus = message_carry_mod_vec[i].0;
cks.key.parameters.carry_modulus = message_carry_mod_vec[i].1;
sks.key.message_modulus = message_carry_mod_vec[i].0;
sks.key.carry_modulus = message_carry_mod_vec[i].1;
let mut msg_space = basis_32bits[i];
let mut rng = rand::thread_rng();
let clear_0 = rng.gen::<u64>() % msg_space;
let clear_1 = rng.gen::<u64>() % msg_space;
let group_name = format!("{}", param.name());
let mut group = c.benchmark_group(group_name.clone());
group.sample_size(10);
// TEST_ADD //
let mut ct_zero_rad = cks.encrypt_radix_with_message_modulus(clear_0, *nb_blocks,
MessageModulus
(*block_modulus));
let mut ct_one_rad = cks.encrypt_radix_with_message_modulus(clear_1, *nb_blocks,
MessageModulus
(*block_modulus));
let id = format!("{}_hybrid_mul_block_{}", group_name, i);
group.bench_function(id, |b| {
b.iter(|| {
let mut ct_res = sks.unchecked_mul(&mut ct_one_rad, &mut ct_zero_rad);
})
});
let id = format!("{}_hybrid_add_block_{}", group_name, i);
group.bench_function(id, |b| {
b.iter(|| {
let mut ct_res = sks.unchecked_add(&mut ct_one_rad, &mut ct_zero_rad);
})
});
let id = format!("{}_hybrid_prop_block_{}", group_name, i);
group.bench_function(id, |b| {
b.iter(|| {
sks.full_propagate(&mut ct_one_rad);
})
});
i = i+1;
}
}

View File

@@ -211,7 +211,7 @@ fn programmable_bootstrapping(c: &mut Criterion) {
let modulus = cks.parameters.message_modulus.0 as u64;
let acc = sks.generate_accumulator(|x| x);
let acc = sks.generate_accumulator(|x| x, );
let clear_0 = rng.gen::<u64>() % modulus;

View File

@@ -34,7 +34,7 @@ pub unsafe extern "C" fn shortint_server_key_generate_pbs_accumulator(
let heap_allocated_accumulator = Box::new(ShortintPBSAccumulator(
server_key
.0
.generate_accumulator(|x: u64| accumulator_callback(x)),
.generate_accumulator(|x: u64| accumulator_callback(x), ),
));
*result = Box::into_raw(heap_allocated_accumulator);

View File

@@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
#[derive(Serialize, Clone, Deserialize)]
pub struct RadixCiphertext {
/// The blocks are stored from LSB to MSB
pub(crate) blocks: Vec<ShortintCiphertext>,
pub blocks: Vec<ShortintCiphertext>,
}
pub trait IntegerCiphertext: Clone {
@@ -52,6 +52,13 @@ impl IntegerCiphertext for CrtCiphertext {
/// the same parameters.
#[derive(Serialize, Clone, Deserialize)]
pub struct CrtCiphertext {
pub(crate) blocks: Vec<ShortintCiphertext>,
pub blocks: Vec<ShortintCiphertext>,
pub(crate) moduli: Vec<u64>,
}
pub fn crt_ciphertext_from_ciphertext(ct: &ShortintCiphertext) -> CrtCiphertext{
CrtCiphertext {
blocks: vec![ct.clone(); 1],
moduli: vec![ct.message_modulus.0 as u64; 1],
}
}

View File

@@ -29,7 +29,7 @@ pub use radix::RadixClientKey;
/// use the same crypto parameters.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct ClientKey {
pub(crate) key: ShortintClientKey,
pub key: ShortintClientKey,
}
impl From<ShortintClientKey> for ClientKey {
@@ -115,6 +115,23 @@ impl ClientKey {
RadixCiphertext { blocks }
}
pub fn encrypt_radix_with_message_modulus(&self, message: u64, num_blocks: usize,
block_message_modulus: MessageModulus) ->
RadixCiphertext {
let mut blocks = Vec::with_capacity(num_blocks);
let mut tmp_clear = message;
for _ in 0..num_blocks{
let tmp = tmp_clear % (block_message_modulus.0 as u64);
blocks.push(self.key.encrypt_with_message_modulus(tmp, block_message_modulus));
tmp_clear = (tmp_clear - tmp)/ (block_message_modulus.0 as u64);
}
RadixCiphertext { blocks }
}
/// Encrypts an integer in radix decomposition without padding bit
///
/// # Example
@@ -235,6 +252,28 @@ impl ClientKey {
result % whole_modulus
}
pub fn decrypt_radix_with_message_modulus(&self, ctxt: &RadixCiphertext) -> u64 {
let mut result = 0_u64;
let mut shift = 1_u64;
let modulus = ctxt.blocks[0].message_modulus.0 as u64;
for c_i in ctxt.blocks.iter() {
// decrypt the component i of the integer and multiply it by the radix product
let block_value = self.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
// update the result
result = result.wrapping_add(block_value);
// update the shift for the next iteration
shift = shift.wrapping_mul(modulus);
}
let whole_modulus = modulus.pow(ctxt.blocks.len() as u32);
result % whole_modulus
}
/// Decrypts a ciphertext encrypting an radix integer encrypted without padding
///
/// # Example

View File

@@ -57,6 +57,9 @@ impl RadixClientKey {
self.key.decrypt_radix(ciphertext)
}
/// Returns the parameters used by the client key.
pub fn parameters(&self) -> ShortintParameters {
self.key.parameters()

View File

@@ -1,4 +1,7 @@
#![allow(clippy::excessive_precision)]
pub mod parameters_benches_joc;
pub use crate::shortint::Parameters;
use crate::shortint::parameters::{CarryModulus, MessageModulus};

View File

@@ -0,0 +1,282 @@
pub use crate::core_crypto::commons::dispersion::{DispersionParameter, StandardDev};
pub use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
};
use crate::shortint::parameters::{CarryModulus, MessageModulus};
use crate::shortint::Parameters;
pub const ID_1_RADIX_16_BITS_16_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(615),
glwe_dimension: GlweDimension(4),
polynomial_size: PolynomialSize(512),
lwe_modular_std_dev: StandardDev(0.00009380341682666086),
glwe_modular_std_dev: StandardDev( 0.0000000000000003162026630747649),
pbs_base_log: DecompositionBaseLog(12),
pbs_level: DecompositionLevelCount(3),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(2),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(2),
carry_modulus: CarryModulus(2),
};
pub const ID_2_RADIX_16_BITS_8_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(702),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.000018916438292526045),
glwe_modular_std_dev: StandardDev( 0.0000000000000003162026630747649),
pbs_base_log: DecompositionBaseLog(9),
pbs_level: DecompositionLevelCount(4),
ks_level: DecompositionLevelCount(7),
ks_base_log: DecompositionBaseLog(2),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
};
pub const ID_3_CRT_16_BITS_5_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(872),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(4096),
lwe_modular_std_dev: StandardDev(0.0000008244869530752798),
glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
pbs_base_log: DecompositionBaseLog(22),
pbs_level: DecompositionLevelCount(1),
ks_level: DecompositionLevelCount(4),
ks_base_log: DecompositionBaseLog(4),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(16),
carry_modulus: CarryModulus(4),
};
pub const ID_4_RADIX_32_BITS_32_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(667),
glwe_dimension: GlweDimension(6),
polynomial_size: PolynomialSize(256),
lwe_modular_std_dev: StandardDev(0.00003604103581022737),
glwe_modular_std_dev: StandardDev(0.000000000003953518398797519),
pbs_base_log: DecompositionBaseLog(18),
pbs_level: DecompositionLevelCount(1),
ks_level: DecompositionLevelCount(3),
ks_base_log: DecompositionBaseLog(4),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(2),
carry_modulus: CarryModulus(2),
};
pub const ID_5_RADIX_32_BITS_16_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(784),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.000004174399189990001),
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
pbs_base_log: DecompositionBaseLog(23),
pbs_level: DecompositionLevelCount(1),
ks_level: DecompositionLevelCount(3),
ks_base_log: DecompositionBaseLog(4),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
};
pub const ID_6_RADIX_32_BITS_8_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(983),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(16384),
lwe_modular_std_dev: StandardDev(0.00000010595830454427828),
glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
pbs_base_log: DecompositionBaseLog(15),
pbs_level: DecompositionLevelCount(2),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(4),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(16),
carry_modulus: CarryModulus(16),
};
pub const ID_6_CRT_32_BITS_6_BLOCKS: Parameters = Parameters {
lwe_dimension: LweDimension(983),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(16384),
lwe_modular_std_dev: StandardDev(0.00000010595830454427828),
glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
pbs_base_log: DecompositionBaseLog(15),
pbs_level: DecompositionLevelCount(2),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(4),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(64),
carry_modulus: CarryModulus(4),
};
pub const ID_7_RADIX_16_BITS_16_BLOCKS_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(549),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.0003177104139262535),
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
pbs_base_log: DecompositionBaseLog(12),
pbs_level: DecompositionLevelCount(3),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(2),
pfks_level: DecompositionLevelCount(2),
pfks_base_log: DecompositionBaseLog(17),
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
cbs_level: DecompositionLevelCount(1),
cbs_base_log: DecompositionBaseLog(13),
message_modulus: MessageModulus(2),
carry_modulus: CarryModulus(2),
};
pub const ID_8_RADIX_16_BITS_8_BLOCKS_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(534),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.0004192214045106218),
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
pbs_base_log: DecompositionBaseLog(12),
pbs_level: DecompositionLevelCount(3),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(2),
pfks_level: DecompositionLevelCount(2),
pfks_base_log: DecompositionBaseLog(17),
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
cbs_level: DecompositionLevelCount(2),
cbs_base_log: DecompositionBaseLog(9),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
};
pub const ID_9_CRT_16_BITS_5_BLOCKS_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(538),
glwe_dimension: GlweDimension(4),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.00038844554870845634),
glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
pbs_base_log: DecompositionBaseLog(4),
pbs_level: DecompositionLevelCount(11),
ks_level: DecompositionLevelCount(10),
ks_base_log: DecompositionBaseLog(1),
pfks_level: DecompositionLevelCount(2),
pfks_base_log: DecompositionBaseLog(20),
pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
cbs_level: DecompositionLevelCount(4),
cbs_base_log: DecompositionBaseLog(7),
message_modulus: MessageModulus(16),
carry_modulus: CarryModulus(4),
};
pub const ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(696),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.00002113509320237618),
glwe_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
pbs_base_log: DecompositionBaseLog(9),
pbs_level: DecompositionLevelCount(4),
ks_level: DecompositionLevelCount(7),
ks_base_log: DecompositionBaseLog(2),
pfks_level: DecompositionLevelCount(2),
pfks_base_log: DecompositionBaseLog(17),
pfks_modular_std_dev: StandardDev(0.0000000000000003162026630747649),
cbs_level: DecompositionLevelCount(3),
cbs_base_log: DecompositionBaseLog(7),
message_modulus: MessageModulus(16),
carry_modulus: CarryModulus(1),
};
pub const ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(791),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(4096),
lwe_modular_std_dev: StandardDev(0.000003659302213002263),
glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
pbs_base_log: DecompositionBaseLog(3),
pbs_level: DecompositionLevelCount(14),
ks_level: DecompositionLevelCount(16),
ks_base_log: DecompositionBaseLog(1),
pfks_level: DecompositionLevelCount(2),
pfks_base_log: DecompositionBaseLog(20),
pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
cbs_level: DecompositionLevelCount(5),
cbs_base_log: DecompositionBaseLog(5),
message_modulus: MessageModulus(64),
carry_modulus: CarryModulus(1),
};
pub const ID_11_BIS_NATIF_CRT_32_BITS_8_BLOCKS_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(781),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_modular_std_dev: StandardDev(0.0000044043577651404615),
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
pbs_base_log: DecompositionBaseLog(5),
pbs_level: DecompositionLevelCount(8),
ks_level: DecompositionLevelCount(16),
ks_base_log: DecompositionBaseLog(1),
pfks_level: DecompositionLevelCount(3),
pfks_base_log: DecompositionBaseLog(13),
pfks_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
cbs_level: DecompositionLevelCount(4),
cbs_base_log: DecompositionBaseLog(6),
message_modulus: MessageModulus(32),
carry_modulus: CarryModulus(1),
};
pub const ID_12_HYBRID_CRT_32_bits: Parameters = Parameters {
lwe_dimension: LweDimension(838),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(4096),
lwe_modular_std_dev: StandardDev(0.0000015398206356719045),
glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009),
pbs_base_log: DecompositionBaseLog(15),
pbs_level: DecompositionLevelCount(2),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(3),
pfks_level: DecompositionLevelCount(0),
pfks_base_log: DecompositionBaseLog(0),
pfks_modular_std_dev: StandardDev(0.0),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus: MessageModulus(8),
carry_modulus: CarryModulus(8),
};
pub const TEST_WOPBS: Parameters = Parameters {
lwe_dimension: LweDimension(10),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(1024),
lwe_modular_std_dev: StandardDev(0.0000000000000000000004168323308734758),
glwe_modular_std_dev: StandardDev(0.00000000000000000000000000000000000004905643852600863),
pbs_base_log: DecompositionBaseLog(7),
pbs_level: DecompositionLevelCount(6),
ks_base_log: DecompositionBaseLog(1),
ks_level: DecompositionLevelCount(14),
pfks_level: DecompositionLevelCount(6),
pfks_base_log: DecompositionBaseLog(7),
pfks_modular_std_dev: StandardDev(0.000000000000000000000000000000000000004905643852600863),
cbs_level: DecompositionLevelCount(7),
cbs_base_log: DecompositionBaseLog(4),
message_modulus: MessageModulus(16),
carry_modulus: CarryModulus(1),
};

View File

@@ -83,7 +83,7 @@ impl ServerKey {
let accumulators = basis
.iter()
.copied()
.map(|b| self.key.generate_accumulator(|x| f(x) % b));
.map(|b| self.key.generate_accumulator(|x| f(x) % b, ));
for (block, acc) in ct1.blocks.iter_mut().zip(accumulators) {
self.key

View File

@@ -28,7 +28,12 @@ impl ServerKey {
/// ```
pub fn unchecked_crt_mul_assign(&self, ct_left: &mut CrtCiphertext, ct_right: &CrtCiphertext) {
for (ct_left, ct_right) in ct_left.blocks.iter_mut().zip(ct_right.blocks.iter()) {
self.key.unchecked_mul_lsb_assign(ct_left, ct_right);
if ct_left.message_modulus.0 <= ct_left.carry_modulus.0 {
self.key.unchecked_mul_lsb_assign(ct_left, ct_right);
}
else {
self.key.unchecked_mul_lsb_small_carry_assign(ct_left, ct_right);
}
}
}

View File

@@ -82,7 +82,7 @@ impl ServerKey {
let accumulators = basis
.iter()
.copied()
.map(|b| self.key.generate_accumulator(|x| f(x) % b))
.map(|b| self.key.generate_accumulator(|x| f(x) % b, ))
.collect::<Vec<_>>();
ct1.blocks

View File

@@ -6,6 +6,8 @@ mod crt;
mod crt_parallel;
mod radix;
mod radix_parallel;
#[cfg(test)]
mod tests;
use crate::integer::client_key::ClientKey;
use crate::shortint::server_key::MaxDegree;
@@ -20,7 +22,7 @@ pub use crate::shortint::CheckError;
/// sends it to the server so it can compute homomorphic integer circuits.
#[derive(Serialize, Deserialize, Clone)]
pub struct ServerKey {
pub(crate) key: crate::shortint::ServerKey,
pub key: crate::shortint::ServerKey,
}
impl From<ServerKey> for crate::shortint::ServerKey {

View File

@@ -11,6 +11,7 @@ mod sub;
use super::ServerKey;
use crate::integer::RadixCiphertext;
use crate::shortint::prelude::MessageModulus;
#[cfg(test)]
mod tests;
@@ -44,6 +45,18 @@ impl ServerKey {
RadixCiphertext { blocks: vec_res }
}
pub fn create_trivial_zero_radix_with_message_modulus(&self, num_blocks: usize,
message_modulus: MessageModulus) ->
RadixCiphertext {
let mut vec_res = Vec::with_capacity(num_blocks);
for _ in 0..num_blocks {
vec_res.push(self.key.create_trivial_with_message_modulus(0_u64, message_modulus));
}
RadixCiphertext { blocks: vec_res }
}
/// Propagate the carry of the 'index' block to the next one.
///
/// # Example

View File

@@ -86,15 +86,26 @@ impl ServerKey {
let mut result_lsb = shifted_ct.clone();
let mut result_msb = shifted_ct;
for res_lsb_i in result_lsb.blocks[index..].iter_mut() {
self.key.unchecked_mul_lsb_assign(res_lsb_i, ct2);
}
let len = result_msb.blocks.len() - 1;
for res_msb_i in result_msb.blocks[index..len].iter_mut() {
self.key.unchecked_mul_msb_assign(res_msb_i, ct2);
}
//if ct1.blocks[0].message_modulus.0 <= ct1.blocks[0].carry_modulus.0 {
for res_lsb_i in result_lsb.blocks[index..].iter_mut() {
self.key.unchecked_mul_lsb_assign(res_lsb_i, ct2);
}
let len = result_msb.blocks.len() - 1;
for res_msb_i in result_msb.blocks[index..len].iter_mut() {
self.key.unchecked_mul_msb_assign(res_msb_i, ct2);
}
//}
// else {
// for res_lsb_i in result_lsb.blocks[index..].iter_mut() {
// self.key.unchecked_mul_lsb_small_carry_assign(res_lsb_i, ct2);
// }
//
// let len = result_msb.blocks.len() - 1;
// for res_msb_i in result_msb.blocks[index..len].iter_mut() {
// self.key.unchecked_mul_msb_small_carry_assign(res_msb_i, ct2);
// }
// }
result_msb = self.blockshift(&result_msb, 1);
self.unchecked_add(&result_lsb, &result_msb)
@@ -207,12 +218,13 @@ impl ServerKey {
///
/// The result is returned as a new ciphertext.
pub fn unchecked_mul(&self, ct1: &RadixCiphertext, ct2: &RadixCiphertext) -> RadixCiphertext {
let mut result = self.create_trivial_zero_radix(ct1.blocks.len());
let mut result = self.create_trivial_zero_radix_with_message_modulus(ct1.blocks.len(),
ct1.blocks[0].message_modulus);
for (i, ct2_i) in ct2.blocks.iter().enumerate() {
let tmp = self.unchecked_block_mul(ct1, ct2_i, i);
let mut tmp = self.unchecked_block_mul(ct1, ct2_i, i);
self.unchecked_add_assign(&mut result, &tmp);
self.smart_add_assign(&mut result, &mut tmp);
}
result

View File

@@ -0,0 +1,905 @@
use crate::integer::keycache::{KEY_CACHE, KEY_CACHE_WOPBS};
use rand::Rng;
use crate::integer::ciphertext::crt_ciphertext_from_ciphertext;
use crate::integer::{CrtCiphertext, IntegerCiphertext, RadixCiphertext};
use crate::integer::parameters::PARAM_4_BITS_5_BLOCKS;
use crate::integer::parameters::parameters_benches_joc::*;
use crate::integer::wopbs::WopbsKey;
use crate::shortint::prelude::{CarryModulus, MessageModulus};
/// Number of assert in randomized tests
const NB_TEST: usize = 10;
//// RADIX ///
#[test]
fn joc_radix_add() {
let param_vec = vec![ID_1_RADIX_16_BITS_16_BLOCKS, ID_2_RADIX_16_BITS_8_BLOCKS,
ID_4_RADIX_32_BITS_32_BLOCKS, ID_5_RADIX_32_BITS_16_BLOCKS, ID_6_RADIX_32_BITS_8_BLOCKS];
let nb_blocks_vec = vec![16, 8, 32, 16, 8];
for (param, nb_blocks) in param_vec.iter().zip(nb_blocks_vec.iter()) {
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(*nb_blocks as u32) as u64;
for _ in 0..NB_TEST {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
// encryption of an integer
let ct_0 = cks.encrypt_radix(clear_0, *nb_blocks);
// encryption of an inter
let ct_1 = cks.encrypt_radix(clear_1, *nb_blocks);
// add the two ciphertexts
let ct_res = sks.unchecked_add(&ct_0, &ct_1);
// decryption of ct_res
let dec_res = cks.decrypt_radix(&ct_res);
// assert
assert_eq!((clear_0 + clear_1) % modulus, dec_res);
}
}
}
#[test]
fn joc_radix_mul() {
let param_vec = vec![
ID_1_RADIX_16_BITS_16_BLOCKS,
ID_2_RADIX_16_BITS_8_BLOCKS,
ID_4_RADIX_32_BITS_32_BLOCKS, // DOES NOT WORK
ID_5_RADIX_32_BITS_16_BLOCKS,
ID_6_RADIX_32_BITS_8_BLOCKS
];
let nb_blocks_vec = vec![
16,
8,
32,
16,
8
];
for (param, nb_blocks) in param_vec.iter().zip(nb_blocks_vec.iter()) {
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(*nb_blocks as u32) as u64;
println!("MODULUS = {modulus}, nb_blocks = {nb_blocks}");
for _ in 0..10 {
let clear_0 = 2725718330; //rng.gen::<u64>() % modulus;
// let clear_1 = 751081786; // rng.gen::<u64>() % modulus;
let clear_1 = 2;
// encryption of an integer
let mut ct_0 = cks.encrypt_radix(clear_0, *nb_blocks);
// encryption of an inter
let mut ct_1 = cks.encrypt_radix(clear_1, *nb_blocks);
println!("DECRYPTED: ct_0 = {}, ct_1 = {}", cks.decrypt_radix(&ct_0), cks
.decrypt_radix(&ct_1));
// mul the two ciphertexts
let mut ct_res = sks.unchecked_mul(&mut ct_0, &mut ct_1);
// let mut ct_res = sks.unchecked_add(&mut ct_0.clone(), &mut ct_0);
// sks.full_propagate(&mut ct_res);
// decryption of ct_res
let dec_res = cks.decrypt_radix(&ct_res);
// assert
assert_eq!((clear_0 * clear_1) % modulus, dec_res);
}
}
}
#[test]
fn joc_radix_carry_propagate() {
let param_vec = vec![
ID_1_RADIX_16_BITS_16_BLOCKS,
ID_2_RADIX_16_BITS_8_BLOCKS,
ID_4_RADIX_32_BITS_32_BLOCKS,
ID_5_RADIX_32_BITS_16_BLOCKS,
ID_6_RADIX_32_BITS_8_BLOCKS
];
let nb_blocks_vec = vec![
16,
8,
32,
16,
8,
];
for (param, nb_blocks) in param_vec.iter().zip(nb_blocks_vec.iter()) {
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(*nb_blocks as u32) as u64;
for _ in 0..NB_TEST {
let clear_0 = rng.gen::<u64>() % modulus;
// encryption of an integer
let mut ct_0 = cks.encrypt_radix(clear_0, *nb_blocks);
sks.full_propagate(&mut ct_0);
// decryption of ct_res
let dec_res = cks.decrypt_radix(&ct_0);
// assert
assert_eq!(clear_0 % modulus, dec_res);
}
}
}
#[test]
pub fn joc_radix_wopbs() {
let param_vec = vec![
ID_7_RADIX_16_BITS_16_BLOCKS_WOPBS,
ID_8_RADIX_16_BITS_8_BLOCKS_WOPBS
];
let nb_blocks_vec = vec![
16,
8,
];
for (param, nb_blocks) in param_vec.iter().zip(nb_blocks_vec.iter()) {
let mut rng = rand::thread_rng();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
let mut msg_space: u64 = param.message_modulus.0 as u64;
for _ in 1..*nb_blocks {
msg_space *= param.message_modulus.0 as u64;
}
for _ in 0..NB_TEST {
let clear1 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_radix(clear1, *nb_blocks);
let lut = wopbs_key.generate_lut_radix(&ct1, |x| x);
let ct_res = wopbs_key.wopbs(&ct1, &lut);
let res_wop = cks.decrypt_radix(&ct_res);
assert_eq!(clear1, res_wop);
}
}
}
/// CRT ///
#[test]
fn joc_crt_add() {
let param_vec = vec![ID_3_CRT_16_BITS_5_BLOCKS, ID_6_CRT_32_BITS_6_BLOCKS];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
let basis_32bits = vec![43,47,37,49,29,41];
let basis_vec = [basis_16bits, basis_32bits];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let modulus = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let mut rng = rand::thread_rng();
for _ in 0..NB_TEST {
let mut clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
// encryption of an integer
let mut ct_zero = cks.encrypt_crt(clear_0, basis.to_vec());
let mut ct_one = cks.encrypt_crt(clear_1, basis.to_vec());
let dec0 = cks.decrypt_crt(&ct_zero);
let dec1 = cks.decrypt_crt(&ct_one);
assert_eq!(dec0, clear_0);
assert_eq!(dec1, clear_1);
// add the two ciphertexts
let ct_res = sks.unchecked_crt_add(&mut ct_zero, &mut ct_one);
// decryption of ct_res
let dec_res = cks.decrypt_crt(&ct_res);
// assert
clear_0 += clear_1;
assert_eq!(clear_0 % modulus, dec_res % modulus);
}
}
}
#[test]
fn joc_crt_mul() {
let param_vec = vec![ID_3_CRT_16_BITS_5_BLOCKS, ID_6_CRT_32_BITS_6_BLOCKS];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
let basis_32bits = vec![43,47,37,49,29,41];
let basis_vec = [basis_16bits, basis_32bits];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let modulus = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let mut rng = rand::thread_rng();
for _ in 0..NB_TEST {
let mut clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
// encryption of an integer
let mut ct_zero = cks.encrypt_crt(clear_0, basis.to_vec());
let mut ct_one = cks.encrypt_crt(clear_1, basis.to_vec());
// add the two ciphertexts
let ct_res = sks.unchecked_crt_mul(&mut ct_zero, &mut ct_one);
// decryption of ct_res
let dec_res = cks.decrypt_crt(&ct_res);
// assert
clear_0 *= clear_1;
assert_eq!(clear_0 % modulus, dec_res % modulus);
}
}
}
#[test]
fn joc_crt_carry_propagate() {
let param_vec = vec![ID_3_CRT_16_BITS_5_BLOCKS, ID_6_CRT_32_BITS_6_BLOCKS];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
let basis_32bits = vec![43,47,37,49,29,41];
let basis_vec = [basis_16bits, basis_32bits];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let modulus = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let mut rng = rand::thread_rng();
for _ in 0..NB_TEST {
let clear_0 = rng.gen::<u64>() % modulus;
// encryption of an integer
let mut ct_zero = cks.encrypt_crt(clear_0, basis.to_vec());
// add the two ciphertexts
sks.full_extract_message_assign(&mut ct_zero);
// decryption of ct_res
let dec_res = cks.decrypt_crt(&ct_zero);
// assert
assert_eq!(clear_0 % modulus, dec_res % modulus);
}
}
}
#[test]
pub fn joc_crt_wopbs() {
let param_vec = vec![
ID_9_CRT_16_BITS_5_BLOCKS_WOPBS,
];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
let basis_vec = [basis_16bits];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
for _ in 0..NB_TEST {
let clear1 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_crt(clear1, basis.to_vec());
let lut = wopbs_key.generate_lut_crt(&ct1, |x| x);
let ct_res = wopbs_key.wopbs(&ct1, &lut);
let res_wop = cks.decrypt_crt(&ct_res);
assert_eq!(clear1, res_wop);
}
}
}
//#[test]
pub fn joc_native_crt_wopbs() {
let param_vec = vec![
ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS,
//ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS
];
// Define CRT basis, and global modulus
let basis_16bits = vec![7,8,9,11,13];
//let basis_32bits = vec![43,47,37,49,29,41];
let basis_vec = [
basis_16bits,
// basis_32bits,
];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
for _ in 0..NB_TEST {
let clear1 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_native_crt(clear1, basis.to_vec());
let lut = wopbs_key.generate_lut_native_crt(&ct1, |x| x);
let ct_res = wopbs_key.wopbs_native_crt(&ct1, &lut);
let res_wop = cks.decrypt_native_crt(&ct_res);
assert_eq!(clear1, res_wop);
}
}
}
#[test]
pub fn joc_native_crt_add() {
let param_vec = vec![
//ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS,
//ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS,
ID_11_BIS_NATIF_CRT_32_BITS_8_BLOCKS_WOPBS,
];
// Define CRT basis, and global modulus
//let basis_16bits = vec![7,8,9,11,13];
//let basis_32bits = vec![43,47,37,49,29,41];
let basis_32bits_bis = vec![3, 11, 13, 19, 23, 29, 31, 32];
let basis_vec = [
//basis_16bits,
//basis_32bits,
basis_32bits_bis,
];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
for _ in 0..NB_TEST {
let clear1 = rng.gen::<u64>() % msg_space;
let clear0 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_native_crt(clear1, basis.to_vec());
let ct0 = cks.encrypt_native_crt(clear0, basis.to_vec());
let ct_res = sks.unchecked_crt_add(&ct1, &ct0);
let res = cks.decrypt_native_crt(&ct_res);
assert_eq!((clear0 + clear1) % msg_space, res);
}
}
}
#[test]
pub fn joc_native_crt_mul_wopbs() {
let param_vec = vec![
//ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS,
ID_11_BIS_NATIF_CRT_32_BITS_8_BLOCKS_WOPBS
];
//let basis_32bits = vec![43,47,37,49,29,41];
let basis_32bits_bis = vec![3, 11, 13, 19, 23, 29, 31, 32];
let basis_vec = [
//basis_32bits,
basis_32bits_bis,
];
for (param, basis) in param_vec.iter().zip(basis_vec.iter()) {
let mut rng = rand::thread_rng();
let msg_space = basis.iter().product::<u64>();
let (cks, sks) = KEY_CACHE.get_from_params(*param);
let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
for _ in 0..NB_TEST {
let clear1 = rng.gen::<u64>() % msg_space;
let clear2 = rng.gen::<u64>() % msg_space;
let ct1 = cks.encrypt_native_crt(clear1, basis.to_vec());
let ct2 = cks.encrypt_native_crt(clear2, basis.to_vec());
let mut ct_res = ct1.clone();
for ((ct_left, ct_right), res) in ct1.blocks.iter().zip(ct2.blocks.iter()).zip
(ct_res.blocks.iter_mut()) {
let crt_left = crt_ciphertext_from_ciphertext(&ct_left);
let crt_right = crt_ciphertext_from_ciphertext(&ct_right);
let mut crt_res = crt_ciphertext_from_ciphertext(&res);
let lut = wopbs_key.generate_lut_bivariate_native_crt(&crt_left, |x,y|
x*y);
crt_res = wopbs_key.bivariate_wopbs_native_crt(&crt_left, &crt_right, &lut);
}
let res_wop = cks.decrypt_native_crt(&ct_res);
assert_eq!(clear1, res_wop);
}
}
}
#[test]
pub fn joc_hybrid_32_bits() {
let param = ID_12_HYBRID_CRT_32_bits;
//let param = ID_6_CRT_32_BITS_6_BLOCKS;
// basis = 2^5 * 3^5* 5^4 * 7^4
let basis_32bits = vec![
32,
243,
625,
2401
];
let modulus_vec = [
8,
3,
5,
7,
];
let nb_blocks_vec = [
4,
5,
4,
4,
];
let message_carry_mod_vec = [
(MessageModulus(8), CarryModulus(8)),
(MessageModulus(8), CarryModulus(8)),
(MessageModulus(8), CarryModulus(8)),
(MessageModulus(8), CarryModulus(8)),
];
//println!("Chosen Parameter Set: {param:?}");
for _ in 0..10 {
let mut i= 0;
for (block_modulus, nb_blocks) in modulus_vec.iter().zip(nb_blocks_vec.iter
()) {
let (mut cks, mut sks) = KEY_CACHE.get_from_params(param);
// sks.key.message_modulus = MessageModulus(*block_modulus);
// sks.key.carry_modulus = CarryModulus(*block_modulus);
cks.key.parameters.message_modulus = message_carry_mod_vec[i].0;
cks.key.parameters.carry_modulus = message_carry_mod_vec[i].1;
sks.key.message_modulus = message_carry_mod_vec[i].0;
sks.key.carry_modulus = message_carry_mod_vec[i].1;
let mut msg_space = basis_32bits[i];
i = i+1;
// println!("block_modulus = {block_modulus}");
// println!("msg_space = {msg_space}");
//
// let mut rng = rand::thread_rng();
// let clear_0 = rng.gen::<u64>() % msg_space;
// let mut cpy_clear_0 = clear_0;
// let mut blocks_crt_0 = vec![];
// for _ in 0..*nb_blocks{
// let tmp = cpy_clear_0 % block_modulus;
// blocks_crt_0.push((cks.encrypt_crt(tmp, vec![*block_modulus])).blocks[0].clone());
// cpy_clear_0 = (cpy_clear_0 - tmp)/ block_modulus;
// }
// let clear_1 = rng.gen::<u64>() % msg_space;
// let mut cpy_clear_1 = clear_1;
// let mut blocks_crt_1 = vec![];
// for _ in 0..*nb_blocks{
// let tmp = cpy_clear_1 % block_modulus;
// blocks_crt_1.push((cks.encrypt_crt(tmp, vec![*block_modulus])).blocks[0].clone());
// cpy_clear_1 = (cpy_clear_1 - tmp)/ block_modulus;
// }
let mut rng = rand::thread_rng();
let clear_0 = rng.gen::<u64>() % msg_space;
let clear_1 = rng.gen::<u64>() % msg_space;
println!("clear 0 {:?}", clear_0);
println!("clear 1 {:?}", clear_1);
// TEST_ADD //
let mut ct_zero_rad = cks.encrypt_radix_with_message_modulus(clear_0, *nb_blocks,
MessageModulus
(*block_modulus));
let mut ct_one_rad = cks.encrypt_radix_with_message_modulus(clear_1, *nb_blocks,
MessageModulus
(*block_modulus));
// for (ct0_i, ct1_i) in ct_zero_rad.blocks.iter_mut().zip(ct_one_rad.blocks.iter_mut()) {
// ct0_i.carry_modulus = CarryModulus(ct0_i.message_modulus.0);
// ct1_i.carry_modulus = CarryModulus(ct0_i.message_modulus.0);
//
// }
println!("CT0 Msg modulus = {}, CT0 carry modulus = {}", ct_zero_rad.blocks[0]
.message_modulus
.0.clone(), ct_zero_rad.blocks[0]
.carry_modulus.0);
println!("CT1 Msg modulus = {}, CT1 carry modulus = {}", ct_one_rad.blocks[0]
.message_modulus
.0.clone(), ct_one_rad.blocks[0]
.carry_modulus.0);
//
let result = cks.decrypt_radix_with_message_modulus(&ct_zero_rad);
assert_eq!(result % msg_space, (clear_0) % msg_space);
let result = cks.decrypt_radix_with_message_modulus(&ct_one_rad);
assert_eq!(result % msg_space, (clear_1) % msg_space);
//TEST ADD
let mut ct_res = sks.unchecked_add(&ct_zero_rad, &ct_one_rad);
let mut result = 0_u64;
let mut shift = 1_u64;
let result = cks.decrypt_radix_with_message_modulus(&ct_res);
println!("add");
println!("dec add {:?}", result);
println!("dec add mod {:?}", result% msg_space);
println!("expected {:?}", (clear_0 + clear_1) % msg_space);
assert_eq!(result % msg_space, (clear_0 + clear_1) % msg_space);
println!("-----");
// TEST_CARRY_PROPAGATE //
sks.full_propagate(&mut ct_res);
let result = cks.decrypt_radix_with_message_modulus(&ct_res);
println!("propagate");
println!("dec propagate {:?}", result);
println!("dec propagate mod {:?}", result% msg_space);
assert_eq!(result % msg_space , (clear_0 + clear_1) % msg_space);
println!("expected {:?}", (clear_0 + clear_1) % msg_space);
println!("-----");
let mut ct_res = sks.unchecked_mul(&mut ct_one_rad, &mut ct_zero_rad);
let result = cks.decrypt_radix_with_message_modulus(&ct_res);
println!("mul");
println!("dec mul {:?}", result);
println!("dec mul mod {:?}", result % msg_space);
println!("clear mul {:?}", (clear_0 * clear_1));
println!("clear mul mod {:?}", (clear_0 * clear_1) % msg_space);
println!("info deg: {:?}", ct_res.blocks[0].degree);
println!("info mm : {:?}", ct_res.blocks[0].message_modulus);
println!("info cm : {:?}", ct_res.blocks[0].carry_modulus);
println!("expected {:?}", (clear_0 * clear_1) % msg_space);
assert_eq!(result % msg_space , (clear_0 * clear_1) % msg_space);
println!("-----");
}
}
//println!("it's OK");
panic!()
}
#[test]
pub fn EXPERIMETAL_hybride_32_bits() {
let param = ID_6_CRT_32_BITS_6_BLOCKS;
// basis = 2^5 * 3^5* 5^4 * 7^4
let basis_32bits = vec![
32,
243,
625,
2420
];
let modulus_vec = [
2,
3,
5,
7,
];
let nb_blocks_vec = [
5,
5,
4,
4,
];
//println!("Chosen Parameter Set: {param:?}");
for _ in 0..2 {
let i= 0;
for (block_modulus, nb_blocks) in modulus_vec.iter().zip(nb_blocks_vec.iter
()) {
let (mut cks, mut sks) = KEY_CACHE.get_from_params(param);
// let mut msg_space = *block_modulus;
// for _ in 1..*nb_blocks {
// msg_space *= *block_modulus;
// }
let mut msg_space = basis_32bits[i];
println!("block_modulus = {block_modulus}");
println!("msg_space = {msg_space}");
let mut rng = rand::thread_rng();
let clear_0 = rng.gen::<u64>() % msg_space;
let mut cpy_clear_0 = clear_0;
let mut blocks_crt_0 = vec![];
for _ in 0..*nb_blocks{
let tmp = cpy_clear_0 % block_modulus;
blocks_crt_0.push((cks.encrypt_crt(tmp, vec![*block_modulus])).blocks[0].clone());
cpy_clear_0 = (cpy_clear_0 - tmp)/ block_modulus;
}
let clear_1 = rng.gen::<u64>() % msg_space;
let mut cpy_clear_1 = clear_1;
let mut blocks_crt_1 = vec![];
for _ in 0..*nb_blocks{
let tmp = cpy_clear_1 % block_modulus;
blocks_crt_1.push((cks.encrypt_crt(tmp, vec![*block_modulus])).blocks[0].clone());
cpy_clear_1 = (cpy_clear_1 - tmp)/ block_modulus;
}
println!("clear 0 {:?}", clear_0);
println!("clear 1 {:?}", clear_1);
println!("add {:?}", clear_0 + clear_1);
println!("add mod {:?}", (clear_1 + clear_0) %msg_space );
// TEST_ADD //
let mut ct_zero_rad = RadixCiphertext::from_blocks(blocks_crt_0);
let mut ct_one_rad = RadixCiphertext::from_blocks(blocks_crt_1);
let mut ct_res = sks.unchecked_add(&ct_zero_rad, &ct_one_rad);
let mut result = 0_u64;
let mut shift = 1_u64;
for c_i in ct_res.blocks().iter() {
// decrypt the component i of the integer and multiply it by the radix product
let block_value = cks.key.decrypt_message_and_carry(c_i);
// update the result
result = result.wrapping_add(block_value.wrapping_mul(shift));
// update the shift for the next iteration
shift = shift.wrapping_mul(*block_modulus);
}
println!("add");
println!("dec add {:?}", result);
println!("dec add mod {:?}", result% msg_space);
assert_eq!(result % msg_space, (clear_0 + clear_1) % msg_space);
println!("-----");
// TEST_CARRY_PROPAGATE //
sks.full_propagate(&mut ct_res);
let mut result = 0_u64;
let mut shift = 1_u64;
for c_i in ct_res.blocks().iter() {
// decrypt the component i of the integer and multiply it by the radix product
let block_value = cks.key.decrypt_message_and_carry(c_i);
// update the result
result = result.wrapping_add(block_value.wrapping_mul(shift));
// update the shift for the next iteration
shift = shift.wrapping_mul(*block_modulus);
}
println!("propagate");
println!("dec propagate {:?}", result);
println!("dec propagate mod {:?}", result% msg_space);
assert_eq!(result % msg_space , (clear_0 + clear_1) % msg_space);
println!("-----");
// TEST_MUL //
ct_one_rad.blocks[0].message_modulus.0 = ct_one_rad.blocks[0].message_modulus.0 *2;
ct_one_rad.blocks[0].carry_modulus.0 = ct_one_rad.blocks[0].carry_modulus.0 /2;
ct_zero_rad.blocks[0].message_modulus.0 = ct_zero_rad.blocks[0].message_modulus.0 *2;
ct_zero_rad.blocks[0].carry_modulus.0 = ct_zero_rad.blocks[0].carry_modulus.0 /2;
sks.key.message_modulus.0 = sks.key.message_modulus.0*2;
sks.key.carry_modulus.0 = sks.key.carry_modulus.0/2;
//cks.parameters().carry_modulus.0 = cks.parameters().carry_modulus.0/2;
//cks.parameters().message_modulus.0 = cks.parameters().message_modulus.0/2;
ct_res = sks.unchecked_mul(&mut ct_one_rad, &mut ct_zero_rad);
//sks.full_propagate(&mut ct_res);
let mut result = 0_u64;
let mut shift = 1_u64;
for c_i in ct_res.blocks().iter() {
// decrypt the component i of the integer and multiply it by the radix product
let block_value = cks.key.decrypt_message_and_carry(c_i);
// update the result
result = result.wrapping_add(block_value.wrapping_mul(shift));
// update the shift for the next iteration
shift = shift.wrapping_mul(*block_modulus);
}
println!("mul");
println!("dec mul {:?}", result);
println!("dec mul mod {:?}", result % msg_space);
println!("clear mul {:?}", (clear_0 * clear_1));
println!("clear mul mod {:?}", (clear_0 * clear_1) % msg_space);
println!("info deg: {:?}", ct_res.blocks[0].degree);
println!("info mm : {:?}", ct_res.blocks[0].message_modulus);
println!("info cm : {:?}", ct_res.blocks[0].carry_modulus);
assert_eq!(result % msg_space , (clear_0 * clear_1) % msg_space);
println!("-----");
}
}
println!("it's OK");
panic!()
}
/*
#[test]
pub fn joc_hybride_32_bits() {
let param = ID_5_RADIX_32_BITS_16_BLOCKS;
// basis = 2^5 * 3^5* 5^4 * 7^4
let basis_32bits = vec![32, 243, 625, 2420];
let modulus_vec = [
2,
3,
5,
7,
];
let nb_blocks_vec = [
5,
5,
4,
4,
];
println!("Chosen Parameter Set: {param:?}");
for _ in 0..NB_TEST {
for (block_modulus, nb_blocks) in modulus_vec.iter().zip(nb_blocks_vec.iter()) {
let (cks, sks) = KEY_CACHE.get_from_params(param);
let mut msg_space = *block_modulus;
for _ in 1..*nb_blocks {
msg_space *= *block_modulus;
}
println!("block_modulus = {block_modulus}");
println!("msg_space = {msg_space}");
let mut rng = rand::thread_rng();
let clear_0 = rng.gen::<u64>() % msg_space;
println!("Expected Result = {}", (clear_0*clear_0) % msg_space);
// encryption of an integer using CRT hacking
let mut ct_zero = cks.encrypt_crt(clear_0, vec![*block_modulus; *nb_blocks]);
let mut ct_one = cks.encrypt_crt(clear_0, vec![*block_modulus; *nb_blocks]);
let mut ct_zero_rad = RadixCiphertext::from_blocks(ct_zero.blocks().to_vec());
let ct_one_rad = RadixCiphertext::from_blocks(ct_one.blocks().to_vec());
// Test carry progration`
let mut ct_tmp = sks.unchecked_add(&ct_one_rad, &ct_zero_rad);
let mut result = 0_u64;
let mut shift = 1_u64;
let modulus = ct_one_rad.blocks[0].message_modulus.0 as u64;
for c_i in ct_one_rad.blocks.iter() {
// decrypt the component i of the integer and multiply it by the radix product
let block_value = cks.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
// update the result
result = result.wrapping_add(block_value);
// update the shift for the next iteration
shift = shift.wrapping_mul(modulus);
}
let dec_res = cks.decrypt_radix(&ct_zero_rad);
println!("FIRST ADD");
assert_eq!(dec_res, (clear_0) % msg_space);
//
// sks.full_propagate(&mut ct_tmp);
//
// /// DECRYPT //
// let mut result = 0_u64;
// let mut shift = 1_u64;
// let modulus = ct_one_rad.blocks[0].message_modulus.0 as u64;
//
// for c_i in ct_tmp.blocks.iter() {
// // decrypt the component i of the integer and multiply it by the radix product
// let block_value = cks.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
//
// // update the result
// result = result.wrapping_add(block_value);
//
// // update the shift for the next iteration
// shift = shift.wrapping_mul(modulus);
// }
//
// let dec_res = result % msg_space;
// println!("FULL PROP");
// assert_eq!(dec_res, (clear_0 + clear_0) % msg_space);
//
//
// ////END TEST CARRY /////////
//
// let ct_res = sks.unchecked_mul(&ct_zero_rad, &ct_one_rad);
//
// /// DECRYPT //
// let mut result = 0_u64;
// let mut shift = 1_u64;
// let modulus = ct_one_rad.blocks[0].message_modulus.0 as u64;
//
// for c_i in ct_res.blocks.iter() {
// // decrypt the component i of the integer and multiply it by the radix product
// let block_value = cks.key.decrypt_message_and_carry(c_i).wrapping_mul(shift);
//
// // update the result
// result = result.wrapping_add(block_value);
//
// // update the shift for the next iteration
// shift = shift.wrapping_mul(modulus);
// }
//
// let dec_res = result % msg_space;
// println!("MUL");
// assert_eq!(dec_res, (clear_0 * clear_0) % msg_space);
}
}
}
*/

View File

@@ -12,12 +12,12 @@ use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertex
use crate::shortint::ciphertext::Degree;
use rayon::prelude::*;
use crate::shortint::Parameters;
use crate::shortint::{Parameters};
use serde::{Deserialize, Serialize};
#[derive(Clone, Serialize, Deserialize)]
pub struct WopbsKey {
wopbs_key: crate::shortint::wopbs::WopbsKey,
pub wopbs_key: crate::shortint::wopbs::WopbsKey,
}
/// ```rust

View File

@@ -8,6 +8,7 @@ use crate::shortint::engine::EngineResult;
use crate::shortint::server_key::{Accumulator, MaxDegree};
use crate::shortint::{Ciphertext, ClientKey, CompressedServerKey, ServerKey};
use std::cmp::min;
use crate::shortint::prelude::{CarryModulus, MessageModulus};
mod add;
mod bitwise_op;
@@ -495,6 +496,41 @@ impl ShortintEngine {
})
}
// Impossible to call the assign function in this case
pub(crate) fn create_trivial_with_message_modulus(
&mut self,
server_key: &ServerKey,
value: u64,
message_modulus: MessageModulus
) -> EngineResult<Ciphertext> {
let lwe_size = server_key
.bootstrapping_key
.output_lwe_dimension()
.to_lwe_size();
let modular_value = value as usize % message_modulus.0;
let delta =
(1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64;
let shifted_value = (modular_value as u64) * delta;
let encoded = Plaintext(shifted_value);
let ct = allocate_and_trivially_encrypt_new_lwe_ciphertext(lwe_size, encoded);
let degree = Degree(modular_value);
Ok(Ciphertext {
ct,
degree,
message_modulus: message_modulus,
carry_modulus: CarryModulus(((server_key.message_modulus.0 * server_key.carry_modulus
.0) as u64 /(
message_modulus.0 as u64)) as usize) ,
})
}
pub(crate) fn create_trivial_assign(
&mut self,
server_key: &ServerKey,

View File

@@ -1,3 +1,4 @@
use crate::shortint::ciphertext::Degree;
use crate::shortint::engine::{EngineResult, ShortintEngine};
use crate::shortint::{Ciphertext, ServerKey};
@@ -69,7 +70,7 @@ impl ShortintEngine {
self.unchecked_add_assign(ct_left, ct_right)?;
// Modulus of the msg in the msg bits
let res_modulus = server_key.message_modulus.0 as u64;
let res_modulus = ct_left.message_modulus.0 as u64;
// Generate the accumulator for the multiplication
let acc = self.generate_accumulator(server_key, |x| {
@@ -86,7 +87,7 @@ impl ShortintEngine {
&mut self,
server_key: &ServerKey,
ct1: &mut Ciphertext,
ct2: &mut Ciphertext,
ct2: &Ciphertext,
) -> EngineResult<Ciphertext> {
//ct1 + ct2
let mut ct_tmp_left = self.unchecked_add(ct1, ct2)?;
@@ -108,16 +109,75 @@ impl ShortintEngine {
self.unchecked_sub(server_key, &ct_tmp_left, &ct_tmp_right)
}
pub(crate) fn unchecked_mul_msb_small_carry_modulus(
&mut self,
server_key: &ServerKey,
ct1: &mut Ciphertext,
ct2: &Ciphertext,
) -> EngineResult<Ciphertext> {
let modulus = ct1.message_modulus.0 as u64;
let deg = (ct1.degree.0 * ct2.degree.0) / ct2.message_modulus.0;
//ct1 + ct2
let mut ct_tmp_left = self.unchecked_add(ct1, ct2)?;
//ct1-ct2
let (mut ct_tmp_right, z) = self.unchecked_sub_with_z(server_key, ct1, ct2)?;
// let acc_add = self.generate_accumulator(server_key, |x| ((((x * x) / 4) / modulus) as
// f64).ceil() as u64 )?;
// let acc_sub =
// self.generate_accumulator(server_key, |x| (((((x - z) * (x - z)) / 4) / modulus) as
// f64).ceil() as u64
// )?;
let acc_add = self.generate_accumulator(server_key, |x| (((x * x) / modulus) / 4) %
ct_tmp_left
.message_modulus.0 as u64)?;
let acc_sub = self.generate_accumulator(server_key, |x| (((x - z) * (x - z) / modulus) /
4) %
ct_tmp_left.message_modulus.0 as u64)?;
self.keyswitch_programmable_bootstrap_assign(server_key, &mut ct_tmp_left, &acc_add)?;
self.keyswitch_programmable_bootstrap_assign(server_key, &mut ct_tmp_right, &acc_sub)?;
//Last subtraction might fill one bit of carry
let mut ct_sub = self.unchecked_sub(server_key, &ct_tmp_left, &ct_tmp_right)?;
let acc_corrective_term = self.generate_accumulator(server_key, |x| (x - x % modulus)/
modulus)?;
self.keyswitch_programmable_bootstrap(server_key, &mut ct_sub, &acc_corrective_term)
}
pub(crate) fn unchecked_mul_lsb_small_carry_modulus_assign(
&mut self,
server_key: &ServerKey,
ct1: &mut Ciphertext,
ct2: &mut Ciphertext,
ct2: &Ciphertext,
) -> EngineResult<()> {
*ct1 = self.unchecked_mul_lsb_small_carry_modulus(server_key, ct1, ct2)?;
Ok(())
}
pub(crate) fn unchecked_mul_msb_small_carry_modulus_assign(
&mut self,
server_key: &ServerKey,
ct1: &mut Ciphertext,
ct2: &Ciphertext,
) -> EngineResult<()> {
*ct1 = self.unchecked_mul_msb_small_carry_modulus(server_key, ct1, ct2)?;
Ok(())
}
pub(crate) fn smart_mul_lsb_assign(
&mut self,
server_key: &ServerKey,

View File

@@ -504,6 +504,40 @@ impl ShortintEngine {
Ok(ciphertext)
}
//
// pub(crate) fn programmable_bootstrapping_native_crt_bivariate(
// &mut self,
// wopbs_key: &WopbsKey,
// ct_1: &mut Ciphertext,
// ct_2: &mut Ciphertext,
// lut: &[u64],
// ) -> EngineResult<Ciphertext> {
// let nb_bit_to_extract =
// f64::log2((ct_in.message_modulus.0 * ct_in.carry_modulus.0) as f64).ceil() as usize;
// let delta_log = DeltaLog(64 - nb_bit_to_extract);
//
// // trick ( ct - delta/2 + delta/2^4 )
// let lwe_size = ct_in.ct.lwe_size().0;
// let mut cont = vec![0u64; lwe_size];
// cont[lwe_size - 1] =
// (1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5));
// let tmp = LweCiphertextOwned::from_container(cont);
//
// lwe_ciphertext_sub_assign(&mut ct_in.ct, &tmp);
//
// let ciphertext = self.extract_bits_circuit_bootstrapping(
// wopbs_key,
// ct_in,
// lut,
// delta_log,
// ExtractedBitsCount(nb_bit_to_extract),
// )?;
//
// Ok(ciphertext)
// Ok(ciphertext)
// }
/// Temporary wrapper.
///
/// # Warning Experimental

View File

@@ -3,6 +3,8 @@ use crate::shortint::parameters::parameters_wopbs_message_carry::*;
use crate::shortint::parameters::parameters_wopbs_prime_moduli::*;
use crate::shortint::parameters::*;
use crate::shortint::wopbs::WopbsKey;
use crate::integer::parameters::parameters_benches_joc::*;
use crate::shortint::{ClientKey, ServerKey};
use lazy_static::*;
use serde::{Deserialize, Serialize};
@@ -369,6 +371,21 @@ impl NamedParam for Parameters {
WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_6,
WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_7,
PARAM_4_BITS_5_BLOCKS,
ID_1_RADIX_16_BITS_16_BLOCKS,
ID_2_RADIX_16_BITS_8_BLOCKS,
ID_3_CRT_16_BITS_5_BLOCKS,
ID_4_RADIX_32_BITS_32_BLOCKS,
ID_5_RADIX_32_BITS_16_BLOCKS,
ID_6_RADIX_32_BITS_8_BLOCKS,
ID_6_CRT_32_BITS_6_BLOCKS,
ID_7_RADIX_16_BITS_16_BLOCKS_WOPBS,
ID_8_RADIX_16_BITS_8_BLOCKS_WOPBS,
ID_9_CRT_16_BITS_5_BLOCKS_WOPBS,
ID_10_NATIF_CRT_16_BITS_5_BLOCKS_WOPBS,
ID_11_NATIF_CRT_32_BITS_6_BLOCKS_WOPBS,
ID_12_HYBRID_CRT_32_bits,
ID_11_BIS_NATIF_CRT_32_BITS_8_BLOCKS_WOPBS,
TEST_WOPBS,
)
);
}

View File

@@ -291,7 +291,7 @@ impl ServerKey {
/// let modulus = cks.parameters.message_modulus.0 as u64;
///
/// // Generate the accumulator for the function f: x -> x^3 mod 2^2
/// let acc = sks.generate_accumulator(|x| x * x * x % modulus);
/// let acc = sks.generate_accumulator(|x| x * x * x % modulus,);
/// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc);
///
/// let dec = cks.decrypt(&ct_res);
@@ -538,6 +538,13 @@ impl ServerKey {
ShortintEngine::with_thread_local_mut(|engine| engine.create_trivial(self, value).unwrap())
}
pub fn create_trivial_with_message_modulus(&self, value: u64, message_modulus: MessageModulus)
-> Ciphertext {
ShortintEngine::with_thread_local_mut(|engine| engine.create_trivial_with_message_modulus
(self, value, message_modulus).unwrap())
}
pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u64) {
ShortintEngine::with_thread_local_mut(|engine| {
engine.create_trivial_assign(self, ct, value).unwrap()

View File

@@ -354,7 +354,7 @@ impl ServerKey {
pub fn unchecked_mul_lsb_small_carry_assign(
&self,
ct_left: &mut Ciphertext,
ct_right: &mut Ciphertext,
ct_right: &Ciphertext,
) {
ShortintEngine::with_thread_local_mut(|engine| {
engine
@@ -363,6 +363,31 @@ impl ServerKey {
})
}
pub fn unchecked_mul_msb_small_carry(
&self,
ct_left: &mut Ciphertext,
ct_right: &mut Ciphertext,
) -> Ciphertext {
ShortintEngine::with_thread_local_mut(|engine| {
engine
.unchecked_mul_msb_small_carry_modulus(self, ct_left, ct_right)
.unwrap()
})
}
pub fn unchecked_mul_msb_small_carry_assign(
&self,
ct_left: &mut Ciphertext,
ct_right: &Ciphertext,
) {
ShortintEngine::with_thread_local_mut(|engine| {
engine
.unchecked_mul_msb_small_carry_modulus_assign(self, ct_left, ct_right)
.unwrap()
})
}
/// Verify if two ciphertexts can be multiplied together in the case where the carry
/// modulus is smaller than the message modulus.
///

View File

@@ -286,7 +286,7 @@ fn shortint_keyswitch_programmable_bootstrap(param: Parameters) {
let ctxt_0 = cks.encrypt(clear_0);
//define the accumulator as identity
let acc = sks.generate_accumulator(|n| n % modulus);
let acc = sks.generate_accumulator(|n| n % modulus, );
// add the two ciphertexts
let ct_res = sks.keyswitch_programmable_bootstrap(&ctxt_0, &acc);
@@ -392,7 +392,7 @@ fn shortint_generate_accumulator(param: Parameters) {
let keys = KEY_CACHE.get_from_param(param);
let (cks, sks) = (keys.client_key(), keys.server_key());
let double = |x| 2 * x;
let acc = sks.generate_accumulator(double);
let acc = sks.generate_accumulator(double, );
//RNG
let mut rng = rand::thread_rng();
@@ -1796,8 +1796,11 @@ fn shortint_unchecked_sub(param: Parameters) {
let modulus = cks.parameters.message_modulus.0 as u64;
for _ in 0..NB_TEST {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;
// let clear1 = rng.gen::<u64>() % modulus;
// let clear2 = rng.gen::<u64>() % modulus;
let clear1 = 3;
let clear2 = 3;
// Encrypt the integers
let ctxt_1 = cks.encrypt(clear1);
@@ -1877,6 +1880,168 @@ fn shortint_mul_small_carry(param: Parameters) {
}
}
/// test multiplication
#[test]
fn shortint_mul_msb_small_carry() {
let keys = KEY_CACHE.get_from_param(PARAM_MESSAGE_2_CARRY_2);
let (cks, sks) = (keys.client_key(), keys.server_key());
//RNG
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
///
///STUPID TEST
///
println!(" £££££ STUPID TEST ££££££ ");
let stupid_modulus = 3;
for i in 0..stupid_modulus {
for j in 0..stupid_modulus {
let ct_stupid_1 = cks.encrypt_with_message_modulus(i, MessageModulus(stupid_modulus
as usize));
let ct_stupid_2 = cks.encrypt_with_message_modulus(j, MessageModulus(stupid_modulus
as usize));
let ct_res_stupid = sks.unchecked_mul_msb(&ct_stupid_1, &ct_stupid_2);
let dec_res_stupid = cks.decrypt(&ct_res_stupid);
println!("{i} * {j} = {dec_res_stupid}");
assert_eq!(dec_res_stupid, (i*j)/stupid_modulus);
}
}
println!(" £££££ END OF STUPID TEST ££££££ ");
///
/// END OF STUPID TEST
///
for _ in 0..50 {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
println!("%%%%%%%%%%%%%%%%%");
println!("clear_0 = {clear_0}, clear_1 = {clear_1}");
// encryption of an integer
let mut ctxt_zero = cks.encrypt_with_message_modulus(clear_0, MessageModulus(modulus as
usize));
// encryption of an integer
let mut ctxt_one = cks.encrypt_with_message_modulus(clear_1, MessageModulus(modulus as
usize));
// multiply together the two ciphertexts
let ct_res = sks.unchecked_mul_msb_small_carry(&mut ctxt_zero, &mut ctxt_one);
// ////
// /// DEBUG OPERATION MODE
// ///
// let modulus = ctxt_one.message_modulus.0 as u64;
// //let deg = (ct1.degree.0 * ct2.degree.0) / ct2.message_modulus.0;
//
// //ct1 + ct2
// let mut ct_tmp_left = sks.unchecked_add(&ctxt_zero, &ctxt_one);
//
// //ct1-ct2
// let (mut ct_tmp_right, z) = sks.unchecked_sub_with_correcting_term(&ctxt_zero, &ctxt_one);
//
//
// println!("ct1 + ct2 = {}, clear0 + clear1 = {}", cks.decrypt_message_and_carry(&ct_tmp_left), clear_0 +
// clear_1);
//
// println!("ct1 - ct2 = {} correcting_term = {z}, clear0 - clear1 = {}", cks
// .decrypt_message_and_carry
// (&ct_tmp_right), (
// (modulus + clear_0 -
// clear_1) % modulus ));
// // let acc_add = sks.generate_accumulator(|x| ((((x * x) / 4) / modulus) as
// // f64).ceil() as u64 );
// // let acc_sub =
// // sks.generate_accumulator(|x| (((((x - z) * (x - z)) / 4) / modulus) as
// // f64).ceil() as u64
// // );
//
// let acc_add = sks.generate_accumulator(|x| ( ( (x * x) /4*(modulus as f64).log2().ceil()
// as u64)
// %
// ct_tmp_left
// .message_modulus.0 as u64));
// let acc_sub =
// sks.generate_accumulator(|x| (((x - z) * (x - z) / 4 * (modulus as f64).log2().ceil()
// as u64) % ct_tmp_left
// .message_modulus.0 as u64));
//
// sks.keyswitch_programmable_bootstrap_assign(&mut ct_tmp_left, &acc_add);
// println!("ct1+ square ct2 = {}", cks.decrypt_message_and_carry
// (&ct_tmp_left), );
//
// println!("deg(ct) = {}", ct_tmp_left.degree.0);
//
//
// sks.keyswitch_programmable_bootstrap_assign(&mut ct_tmp_right, &acc_sub);
// println!("ct1- square ct2 = {}", cks.decrypt_message_and_carry(&ct_tmp_right));
//
// println!("deg(ct) = {}", ct_tmp_right.degree.0);
//
//
//
// //Last subtraction might fill one bit of carry
// let ct_res = sks.unchecked_sub(&mut ct_tmp_left, &mut ct_tmp_right);
//
// ////
// /// END OF DEBUG OPERATION MODE
// ///
// decryption of ct_res
let dec_res = cks.decrypt(&ct_res);
// // assert
// assert_eq!(((clear_0 * clear_1)/ct_res.message_modulus.0 as u64) % ct_res.message_modulus.0
// as u64, dec_res % ct_res.message_modulus.0 as u64);
//Comparison with the usual msb multiplication
let ct_res_classical = sks.unchecked_mul_msb(&mut ctxt_zero, &mut ctxt_one);
let dec_res_classical = cks.decrypt(&ct_res_classical);
assert_eq!(dec_res, dec_res_classical);
}
}
/// test multiplication
#[test]
fn shortint_mul_msb_not_power_of_two() {
let keys = KEY_CACHE.get_from_param(PARAM_MESSAGE_3_CARRY_3);
let (cks, sks) = (keys.client_key(), keys.server_key());
//RNG
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
///
///STUPID TEST
///
println!(" £££££ STUPID TEST ££££££ ");
let stupid_modulus = 3;
for i in 0..stupid_modulus {
for j in 0..stupid_modulus {
let ct_stupid_1 = cks.encrypt_with_message_modulus(i, MessageModulus(stupid_modulus
as usize));
let ct_stupid_2 = cks.encrypt_with_message_modulus(j, MessageModulus(stupid_modulus
as usize));
let ct_res_stupid = sks.unchecked_mul_msb(&ct_stupid_1, &ct_stupid_2);
let dec_res_stupid = cks.decrypt(&ct_res_stupid);
println!("{i} * {j} = {dec_res_stupid}");
assert_eq!(dec_res_stupid, (i*j)/stupid_modulus);
}
}
println!(" £££££ END OF STUPID TEST ££££££ ");
}
/// test encryption and decryption with the LWE client key
fn shortint_encrypt_with_message_modulus_smart_add_and_mul(param: Parameters) {
let keys = KEY_CACHE.get_from_param(param);
@@ -1936,3 +2101,4 @@ fn shortint_mux(param: Parameters) {
println!("(msg_true - msg_false) * control_bit + msg_false = {clear_mux}, res = {dec_res}");
assert_eq!(clear_mux, dec_res);
}

View File

@@ -190,6 +190,50 @@ impl WopbsKey {
vec_lut
}
//TODO
pub fn generate_lut_bivariate_native_crt<F>(&self, ct_1: &Ciphertext, f: F) -> Vec<u64>
where
F: Fn(u64, u64) -> u64,
{
let mut bit = vec![];
let mut total_bit = 0;
let mut modulus = 1;
let basis = ct_1.message_modulus.0 as u64;
modulus = basis;
let b = f64::log2(basis as f64).ceil() as u64;
total_bit += b;
bit.push(b);
let mut lut_size = 1 << (2 * total_bit);
if 1 << (2 * total_bit) < self.param.polynomial_size.0 as u64 {
lut_size = self.param.polynomial_size.0;
}
let mut vec_lut = vec![0; lut_size];
for value in 0..1 << (2 * total_bit) {
let value_1 = (value % (1 << total_bit)) as u64;
let value_2 = (value >> total_bit) as u64;
let mut index_lut_1 = 0;
let mut index_lut_2 = 0;
let mut tmp = 1;
for bit in bit.iter() {
index_lut_1 += (((value_1 % basis) << bit) / basis) * tmp;
index_lut_2 += (((value_2 % basis) << bit) / basis) * tmp;
tmp <<= bit;
}
let index = (index_lut_2 << total_bit) + (index_lut_1);
vec_lut[index as usize] =
(((f(value_1, value_2) % b) as u128 * (1 << 64)) / basis as u128) as u64
}
vec_lut
}
/// Apply the Look-Up Table homomorphically using the WoPBS approach.
///
/// #Warning: this assumes one bit of padding.
@@ -321,6 +365,39 @@ impl WopbsKey {
})
}
// /// Apply the Look-Up Table homomorphically using the WoPBS approach.
// ///
// /// # Example
// ///
// /// ```rust
// /// use tfhe::shortint::ciphertext::Ciphertext;
// /// use tfhe::shortint::gen_keys;
// /// use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_3_NORM2_2;
// /// use tfhe::shortint::wopbs::*;
// ///
// /// let (cks, sks) = gen_keys(WOPBS_PARAM_MESSAGE_3_NORM2_2);
// /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
// /// let msg = 2;
// /// let modulus = 5;
// /// let mut ct = cks.encrypt_native_crt(msg, modulus);
// /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x);
// /// let ct_res = wopbs_key.programmable_bootstrapping_native_crt(&mut ct, &lut);
// /// let res = cks.decrypt_message_native_crt(&ct_res, modulus);
// /// assert_eq!(res, msg);
// /// ```
// pub fn programmable_bootstrapping_native_crt_bivariate(
// &self,
// ct_1: &mut Ciphertext,
// ct_2: &mut Ciphertext,
// lut: &[u64],
// ) -> Ciphertext {
// ShortintEngine::with_thread_local_mut(|engine| {
// engine
// .programmable_bootstrapping_native_crt_bivariate(self, ct_1, ct_2, lut)
// .unwrap()
// })
// }
/// Extract the given number of bits from a ciphertext.
///
/// # Warning Experimental