Compare commits

...

2 Commits

Author SHA1 Message Date
J-B Orfila
2b7ed85546 examples 2023-04-11 14:59:15 +02:00
J-B Orfila
8bc7c1d217 fix(integer): fix unchecked_add in unchecked_mul 2023-04-07 12:05:51 +02:00
4 changed files with 159 additions and 27 deletions

View File

@@ -1,28 +1,81 @@
use tfhe::boolean::client_key::ClientKey;
use tfhe::boolean::parameters::TFHE_LIB_PARAMETERS;
use tfhe::boolean::prelude::BinaryBooleanGates;
use tfhe::boolean::server_key::ServerKey;
use tfhe::integer::{gen_keys_crt, gen_keys_radix};
use tfhe::prelude::*;
use tfhe::shortint::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
fn main() {
// let (cks, sks) = gen_keys();
let cks = ClientKey::new(&TFHE_LIB_PARAMETERS);
let sks = ServerKey::new(&cks);
let left = false;
let right = true;
let ct_left = cks.encrypt(left);
let ct_right = cks.encrypt(right);
let start = std::time::Instant::now();
let num_loops: usize = 10000;
for _ in 0..num_loops {
let _ = sks.and(&ct_left, &ct_right);
}
let elapsed = start.elapsed().as_millis() as f64;
let mean: f64 = elapsed / num_loops as f64;
println!("{elapsed:?} ms, mean {mean:?} ms");
// crt_mul();
// min_blog_post_example();
hl_api_example();
}
fn hl_api_example() {
// Client-side
let config = ConfigBuilder::all_disabled()
.enable_default_uint16()
.build();
let (client_key, server_key) = generate_keys(config);
let clear_a = 12345u16;
let clear_b = 6789u16;
let clear_c = 1011u16;
let a = FheUint16::encrypt(clear_a, &client_key);
let b = FheUint16::encrypt(clear_b, &client_key);
let c = FheUint16::encrypt(clear_c, &client_key);
// Server-side
set_server_key(server_key);
let result = ((a << 2u16) * b) + c;
// Client-side
let decrypted_result: u16 = result.decrypt(&client_key);
let clear_result = ((clear_a << 2) * clear_b) + clear_c;
assert_eq!(decrypted_result, clear_result);
}
fn crt_mul() {
//CRT-based integer modulus 3*4*5*7 = 420
//To work with homomorphic unsigned integers > 8 bits
let basis = vec![3, 4, 5, 7];
let modulus = 420;
let param = PARAM_MESSAGE_3_CARRY_3;
let (cks, sks) = gen_keys_crt(&param, basis.clone());
let clear_0 = 234;
let clear_1 = 123;
// encryption of an integer
let mut ct_zero = cks.encrypt(clear_0);
let mut ct_one = cks.encrypt(clear_1);
// mul the two ciphertexts
let ct_res = sks.smart_crt_mul_parallelized(&mut ct_zero, &mut ct_one);
// decryption of ct_res
let dec_res = cks.decrypt(&ct_res);
assert_eq!((clear_0 * clear_1) % modulus, dec_res % modulus);
}
fn min_blog_post_example() {
let param = PARAM_MESSAGE_2_CARRY_2;
//Radix-based integers over 8 bits
let num_block = 4;
let (cks, sks) = gen_keys_radix(&param, num_block);
let clear_0 = 157;
let clear_1 = 127;
let mut ct_0 = cks.encrypt(clear_0);
let mut ct_1 = cks.encrypt(clear_1);
let ct_res = sks.smart_min_parallelized(&mut ct_0, &mut ct_1);
let dec_res = cks.decrypt(&ct_res);
assert_eq!(u64::min(clear_0, clear_1), dec_res);
}

View File

@@ -1,3 +1,4 @@
use crate::integer::gen_keys_crt;
use crate::integer::keycache::KEY_CACHE;
use crate::shortint::parameters::*;
use crate::shortint::Parameters;
@@ -271,3 +272,29 @@ fn integer_smart_crt_sub(param: Parameters) {
assert_eq!(clear_0, dec_res);
}
}
#[test]
fn mul_blog_post() {
//CRT-based integer modulus 3*4*5*7 = 420
//To work with homomorphic unsigned integers > 8 bits
let basis = vec![3, 4, 5, 7];
let param = PARAM_MESSAGE_3_CARRY_3;
let (cks, sks) = gen_keys_crt(&param, basis.clone());
let clear_0 = 234;
let clear_1 = 123;
// encryption of an integer
let mut ct_zero = cks.encrypt(clear_0);
let mut ct_one = cks.encrypt(clear_1);
// mul the two ciphertexts
let ct_res = sks.smart_crt_mul(&mut ct_zero, &mut ct_one);
// decryption of ct_res
let dec_res = cks.decrypt(&ct_res);
let modulus = 420;
assert_eq!((clear_0 * clear_1) % modulus, dec_res % modulus);
}

View File

@@ -219,9 +219,9 @@ impl ServerKey {
let mut result = self.create_trivial_zero_radix(ct1.blocks.len());
for (i, ct2_i) in ct2.blocks.iter().enumerate() {
let tmp = self.unchecked_block_mul(ct1, ct2_i, i);
let mut tmp = self.unchecked_block_mul(ct1, ct2_i, i);
self.unchecked_add_assign(&mut result, &tmp);
self.smart_add_assign(&mut result, &mut tmp);
}
result

View File

@@ -1,3 +1,4 @@
use crate::integer::gen_keys_radix;
use crate::integer::keycache::KEY_CACHE;
use crate::shortint::parameters::*;
use crate::shortint::Parameters;
@@ -39,6 +40,7 @@ create_parametrized_test!(integer_smart_sub);
create_parametrized_test!(integer_unchecked_block_mul);
create_parametrized_test!(integer_smart_block_mul);
create_parametrized_test!(integer_smart_mul);
create_parametrized_test!(integer_unchecked_mul);
create_parametrized_test!(integer_smart_scalar_sub);
create_parametrized_test!(integer_smart_scalar_add);
@@ -957,6 +959,35 @@ fn integer_smart_block_mul(param: Parameters) {
}
}
fn integer_unchecked_mul(param: Parameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param);
//RNG
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64;
for _ in 0..1 {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;
// Encrypt the integers;;
let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT);
let ctxt_2 = cks.encrypt_radix(clear2, NB_CTXT);
let res = sks.unchecked_mul(&ctxt_1, &ctxt_2);
let dec: u64 = cks.decrypt_radix(&res);
let expected = (clear1 * clear2) % modulus;
// Check the correctness
assert_eq!(expected, dec);
}
}
fn integer_smart_mul(param: Parameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param);
@@ -1193,3 +1224,24 @@ fn integer_smart_scalar_mul_decomposition_overflow() {
assert_eq!((clear_0 * scalar as u128), dec_res);
}
#[test]
fn min_blog_post_example() {
let param = PARAM_MESSAGE_2_CARRY_2;
//Radix-based integers over 8 bits
let num_block = 4;
let (cks, sks) = gen_keys_radix(&param, num_block);
let clear_0 = 157;
let clear_1 = 127;
let mut ct_0 = cks.encrypt(clear_0);
let mut ct_1 = cks.encrypt(clear_1);
let ct_res = sks.smart_min_parallelized(&mut ct_0, &mut ct_1);
let dec_res = cks.decrypt(&ct_res);
assert_eq!(u64::min(clear_0, clear_1), dec_res);
}